Revamp how CAPs are tracked through REQ and ACK/NAK etc

This commit is contained in:
jesopo 2019-05-11 18:22:40 +01:00
parent 6ef7f8374d
commit d291cd5063
5 changed files with 68 additions and 55 deletions

View file

@ -19,14 +19,20 @@ CAPABILITIES = [
utils.irc.Capability(None, "draft/setname") utils.irc.Capability(None, "draft/setname")
] ]
def _match_caps(capabilities): def _match_caps(our_capabilities, offered_capabilities):
matched = [] matched = {}
for capability in CAPABILITIES: for capability in our_capabilities:
available = capability.available(capabilities) available = capability.available(offered_capabilities)
if available: if available:
matched.append(available) matched[available] = capability
return matched return matched
def _caps_offered(server, caps):
blacklist = server.get_setting("blacklisted-caps", [])
for cap_name, cap in caps.items():
if not cap_name in blacklist:
server.capability_queue[cap_name] = cap
def cap(events, event): def cap(events, event):
capabilities = utils.parse.keyvalue(event["args"][-1]) capabilities = utils.parse.keyvalue(event["args"][-1])
subcommand = event["args"][1].lower() subcommand = event["args"][1].lower()
@ -36,20 +42,18 @@ def cap(events, event):
event["server"].cap_started = True event["server"].cap_started = True
event["server"].server_capabilities.update(capabilities) event["server"].server_capabilities.update(capabilities)
if not is_multiline: if not is_multiline:
matched_caps = _match_caps( server_caps = list(event["server"].server_capabilities.keys())
list(event["server"].server_capabilities.keys())) matched_caps = _match_caps(CAPABILITIES, server_caps)
blacklisted_caps = event["server"].get_setting(
"blacklisted-caps", [])
matched_caps = list(
set(matched_caps)-set(blacklisted_caps))
event["server"].queue_capabilities(matched_caps) module_caps = events.on("received.cap.ls").call(
events.on("received.cap.ls").call(
capabilities=event["server"].server_capabilities, capabilities=event["server"].server_capabilities,
server=event["server"]) server=event["server"])
module_caps = list(filter(None, module_caps))
matched_caps.update(_match_caps(module_caps, server_caps))
if event["server"].has_capability_queue(): _caps_offered(event["server"], matched_caps)
if event["server"].capability_queue:
event["server"].send_capability_queue() event["server"].send_capability_queue()
else: else:
event["server"].send_capability_end() event["server"].send_capability_end()
@ -58,12 +62,15 @@ def cap(events, event):
event["server"].server_capabilities.update(capabilities) event["server"].server_capabilities.update(capabilities)
matched_caps = _match_caps(list(capabilities_keys)) matched_caps = _match_caps(list(capabilities_keys))
event["server"].queue_capabilities(matched_caps)
events.on("received.cap.new").call(server=event["server"], module_caps = events.on("received.cap.new").call(
capabilities=capabilities) server=event["server"], capabilities=capabilities)
module_caps = list(filter(None, module_caps))
matched_caps.update(_match_caps(module_caps, capabilities_keys))
if event["server"].has_capability_queue(): _caps_offered(event["server"], matched_caps)
if event["server"].capability_queue:
event["server"].send_capability_queue() event["server"].send_capability_queue()
elif subcommand == "del": elif subcommand == "del":
for capability in capabilities.keys(): for capability in capabilities.keys():
@ -78,11 +85,17 @@ def cap(events, event):
server=event["server"]) server=event["server"])
if subcommand == "ack" or subcommand == "nak": if subcommand == "ack" or subcommand == "nak":
ack = subcommand == "ack"
for capability in capabilities: for capability in capabilities:
event["server"].requested_capabilities.remove(capability) cap_obj = event["server"].capability_queue[capability]
del event["server"].capability_queue[capability]
if ack:
cap_obj.ack()
else:
cap_obj.nak()
if (event["server"].cap_started and if (event["server"].cap_started and
not event["server"].requested_capabilities and not event["server"].capability_queue and
not event["server"].waiting_for_capabilities()): not event["server"].waiting_for_capabilities()):
event["server"].cap_started = False event["server"].cap_started = False
event["server"].send_capability_end() event["server"].send_capability_end()

View file

@ -23,12 +23,12 @@ class Module(ModuleManager.BaseModule):
@utils.hook("received.cap.ls") @utils.hook("received.cap.ls")
def on_cap_ls(self, event): def on_cap_ls(self, event):
if CAP in event["capabilities"]: if CAP in event["capabilities"]:
event["server"].queue_capability(CAP) cap = utils.irc.Capability(CAP)
cap.on_ack(lambda: self._cap_ack(event["server"]))
return cap
@utils.hook("received.cap.ack") def _cap_ack(self, server):
def on_cap_ack(self, event): server.wait_for_capability("resume")
if CAP in event["capabilities"]:
event["server"].wait_for_capability("resume")
@utils.hook("received.resume") @utils.hook("received.resume")
def on_resume(self, event): def on_resume(self, event):

View file

@ -41,22 +41,22 @@ class Module(ModuleManager.BaseModule):
do_sasl = True do_sasl = True
if do_sasl: if do_sasl:
event["server"].queue_capability("sasl") cap = utils.irc.Capability("sasl")
cap.on_ack(lambda: self._sasl_ack(event["server"]))
return cap
@utils.hook("received.cap.ack") def _sasl_ack(self, server):
def on_cap_ack(self, event): sasl = server.get_setting("sasl")
if "sasl" in event["capabilities"]: mechanism = sasl["mechanism"].upper()
sasl = event["server"].get_setting("sasl") if mechanism == "USERPASS":
mechanism = sasl["mechanism"].upper() server_mechanisms = server.server_capabilities["sasl"]
if mechanism == "USERPASS": server_mechanisms = server_mechanisms or [
server_mechanisms = event["server"].server_capabilities["sasl"] USERPASS_MECHANISMS[0]]
server_mechanisms = server_mechanisms or [ mechanism = self._best_userpass_mechanism(server_mechanisms)
USERPASS_MECHANISMS[0]]
mechanism = self._best_userpass_mechanism(server_mechanisms)
event["server"].send_authenticate(mechanism) server.send_authenticate(mechanism)
event["server"].sasl_mechanism = mechanism server.sasl_mechanism = mechanism
event["server"].wait_for_capability("sasl") server.wait_for_capability("sasl")
@utils.hook("received.authenticate") @utils.hook("received.authenticate")
def on_authenticate(self, event): def on_authenticate(self, event):

View file

@ -24,10 +24,10 @@ class Server(IRCObject.Object):
self.realname = None # type: typing.Optional[str] self.realname = None # type: typing.Optional[str]
self.hostname = None # type: typing.Optional[str] self.hostname = None # type: typing.Optional[str]
self._capability_queue = set([]) # type: typing.Set[str] self.capability_queue = {
} # type: typing.Dict[str, utils.irc.Capability]
self._capabilities_waiting = set([]) # type: typing.Set[str] self._capabilities_waiting = set([]) # type: typing.Set[str]
self.agreed_capabilities = set([]) # type: typing.Set[str] self.agreed_capabilities = set([]) # type: typing.Set[str]
self.requested_capabilities = [] # type: typing.List[str]
self.server_capabilities = {} # type: typing.Dict[str, str] self.server_capabilities = {} # type: typing.Dict[str, str]
self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine] self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine]
self.cap_started = False self.cap_started = False
@ -269,21 +269,12 @@ class Server(IRCObject.Object):
def send_capibility_ls(self) -> IRCLine.SentLine: def send_capibility_ls(self) -> IRCLine.SentLine:
return self.send(utils.irc.protocol.capability_ls()) return self.send(utils.irc.protocol.capability_ls())
def queue_capability(self, capability: str):
self._capability_queue.add(capability)
def queue_capabilities(self, capabilities: typing.List[str]):
self._capability_queue.update(capabilities)
def send_capability_queue(self): def send_capability_queue(self):
if self.has_capability_queue(): capability_queue = [cap for cap in self.capability_queue.keys()]
capability_queue = list(self._capability_queue)
self._capability_queue.clear()
for i in range(0, len(capability_queue), 10): for i in range(0, len(capability_queue), 10):
capability_batch = capability_queue[i:i+10] capability_batch = capability_queue[i:i+10]
self.requested_capabilities += capability_batch self.send_capability_request(" ".join(capability_batch))
self.send_capability_request(" ".join(capability_batch))
def has_capability_queue(self):
return bool(len(self._capability_queue))
def send_capability_request(self, capability: str) -> IRCLine.SentLine: def send_capability_request(self, capability: str) -> IRCLine.SentLine:
return self.send(utils.irc.protocol.capability_request(capability)) return self.send(utils.irc.protocol.capability_request(capability))
def send_capability_end(self) -> IRCLine.SentLine: def send_capability_end(self) -> IRCLine.SentLine:

View file

@ -281,8 +281,17 @@ class Capability(object):
self._caps = set([name, draft_name]) self._caps = set([name, draft_name])
self._name = name self._name = name
self._draft_name = draft_name self._draft_name = draft_name
self._on_ack_callbacks = []
def available(self, capabilities: typing.List[str]) -> str: def available(self, capabilities: typing.List[str]) -> str:
match = list(set(capabilities)&self._caps) match = list(set(capabilities)&self._caps)
return match[0] if match else None return match[0] if match else None
def enabled(self, capability: str) -> bool: def enabled(self, capability: str) -> bool:
return capability in self._caps return capability in self._caps
def on_ack(self, callback: typing.Callable[[], None]):
self._on_ack_callbacks.append(callback)
def ack(self):
for callback in self._on_ack_callbacks:
callback()
def nak(self):
pass