Revamp how CAPs are tracked through REQ and ACK/NAK etc

This commit is contained in:
jesopo 2019-05-11 18:22:40 +01:00
parent 6ef7f8374d
commit d291cd5063
5 changed files with 68 additions and 55 deletions

View file

@ -19,14 +19,20 @@ CAPABILITIES = [
utils.irc.Capability(None, "draft/setname")
]
def _match_caps(capabilities):
matched = []
for capability in CAPABILITIES:
available = capability.available(capabilities)
def _match_caps(our_capabilities, offered_capabilities):
matched = {}
for capability in our_capabilities:
available = capability.available(offered_capabilities)
if available:
matched.append(available)
matched[available] = capability
return matched
def _caps_offered(server, caps):
blacklist = server.get_setting("blacklisted-caps", [])
for cap_name, cap in caps.items():
if not cap_name in blacklist:
server.capability_queue[cap_name] = cap
def cap(events, event):
capabilities = utils.parse.keyvalue(event["args"][-1])
subcommand = event["args"][1].lower()
@ -36,20 +42,18 @@ def cap(events, event):
event["server"].cap_started = True
event["server"].server_capabilities.update(capabilities)
if not is_multiline:
matched_caps = _match_caps(
list(event["server"].server_capabilities.keys()))
blacklisted_caps = event["server"].get_setting(
"blacklisted-caps", [])
matched_caps = list(
set(matched_caps)-set(blacklisted_caps))
server_caps = list(event["server"].server_capabilities.keys())
matched_caps = _match_caps(CAPABILITIES, server_caps)
event["server"].queue_capabilities(matched_caps)
events.on("received.cap.ls").call(
module_caps = events.on("received.cap.ls").call(
capabilities=event["server"].server_capabilities,
server=event["server"])
module_caps = list(filter(None, module_caps))
matched_caps.update(_match_caps(module_caps, server_caps))
if event["server"].has_capability_queue():
_caps_offered(event["server"], matched_caps)
if event["server"].capability_queue:
event["server"].send_capability_queue()
else:
event["server"].send_capability_end()
@ -58,12 +62,15 @@ def cap(events, event):
event["server"].server_capabilities.update(capabilities)
matched_caps = _match_caps(list(capabilities_keys))
event["server"].queue_capabilities(matched_caps)
events.on("received.cap.new").call(server=event["server"],
capabilities=capabilities)
module_caps = events.on("received.cap.new").call(
server=event["server"], capabilities=capabilities)
module_caps = list(filter(None, module_caps))
matched_caps.update(_match_caps(module_caps, capabilities_keys))
if event["server"].has_capability_queue():
_caps_offered(event["server"], matched_caps)
if event["server"].capability_queue:
event["server"].send_capability_queue()
elif subcommand == "del":
for capability in capabilities.keys():
@ -78,11 +85,17 @@ def cap(events, event):
server=event["server"])
if subcommand == "ack" or subcommand == "nak":
ack = subcommand == "ack"
for capability in capabilities:
event["server"].requested_capabilities.remove(capability)
cap_obj = event["server"].capability_queue[capability]
del event["server"].capability_queue[capability]
if ack:
cap_obj.ack()
else:
cap_obj.nak()
if (event["server"].cap_started and
not event["server"].requested_capabilities and
not event["server"].capability_queue and
not event["server"].waiting_for_capabilities()):
event["server"].cap_started = False
event["server"].send_capability_end()

View file

@ -23,12 +23,12 @@ class Module(ModuleManager.BaseModule):
@utils.hook("received.cap.ls")
def on_cap_ls(self, event):
if CAP in event["capabilities"]:
event["server"].queue_capability(CAP)
cap = utils.irc.Capability(CAP)
cap.on_ack(lambda: self._cap_ack(event["server"]))
return cap
@utils.hook("received.cap.ack")
def on_cap_ack(self, event):
if CAP in event["capabilities"]:
event["server"].wait_for_capability("resume")
def _cap_ack(self, server):
server.wait_for_capability("resume")
@utils.hook("received.resume")
def on_resume(self, event):

View file

@ -41,22 +41,22 @@ class Module(ModuleManager.BaseModule):
do_sasl = True
if do_sasl:
event["server"].queue_capability("sasl")
cap = utils.irc.Capability("sasl")
cap.on_ack(lambda: self._sasl_ack(event["server"]))
return cap
@utils.hook("received.cap.ack")
def on_cap_ack(self, event):
if "sasl" in event["capabilities"]:
sasl = event["server"].get_setting("sasl")
mechanism = sasl["mechanism"].upper()
if mechanism == "USERPASS":
server_mechanisms = event["server"].server_capabilities["sasl"]
server_mechanisms = server_mechanisms or [
USERPASS_MECHANISMS[0]]
mechanism = self._best_userpass_mechanism(server_mechanisms)
def _sasl_ack(self, server):
sasl = server.get_setting("sasl")
mechanism = sasl["mechanism"].upper()
if mechanism == "USERPASS":
server_mechanisms = server.server_capabilities["sasl"]
server_mechanisms = server_mechanisms or [
USERPASS_MECHANISMS[0]]
mechanism = self._best_userpass_mechanism(server_mechanisms)
event["server"].send_authenticate(mechanism)
event["server"].sasl_mechanism = mechanism
event["server"].wait_for_capability("sasl")
server.send_authenticate(mechanism)
server.sasl_mechanism = mechanism
server.wait_for_capability("sasl")
@utils.hook("received.authenticate")
def on_authenticate(self, event):

View file

@ -24,10 +24,10 @@ class Server(IRCObject.Object):
self.realname = None # type: typing.Optional[str]
self.hostname = None # type: typing.Optional[str]
self._capability_queue = set([]) # type: typing.Set[str]
self.capability_queue = {
} # type: typing.Dict[str, utils.irc.Capability]
self._capabilities_waiting = set([]) # type: typing.Set[str]
self.agreed_capabilities = set([]) # type: typing.Set[str]
self.requested_capabilities = [] # type: typing.List[str]
self.server_capabilities = {} # type: typing.Dict[str, str]
self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine]
self.cap_started = False
@ -269,21 +269,12 @@ class Server(IRCObject.Object):
def send_capibility_ls(self) -> IRCLine.SentLine:
return self.send(utils.irc.protocol.capability_ls())
def queue_capability(self, capability: str):
self._capability_queue.add(capability)
def queue_capabilities(self, capabilities: typing.List[str]):
self._capability_queue.update(capabilities)
def send_capability_queue(self):
if self.has_capability_queue():
capability_queue = list(self._capability_queue)
self._capability_queue.clear()
capability_queue = [cap for cap in self.capability_queue.keys()]
for i in range(0, len(capability_queue), 10):
capability_batch = capability_queue[i:i+10]
self.requested_capabilities += capability_batch
self.send_capability_request(" ".join(capability_batch))
def has_capability_queue(self):
return bool(len(self._capability_queue))
for i in range(0, len(capability_queue), 10):
capability_batch = capability_queue[i:i+10]
self.send_capability_request(" ".join(capability_batch))
def send_capability_request(self, capability: str) -> IRCLine.SentLine:
return self.send(utils.irc.protocol.capability_request(capability))
def send_capability_end(self) -> IRCLine.SentLine:

View file

@ -281,8 +281,17 @@ class Capability(object):
self._caps = set([name, draft_name])
self._name = name
self._draft_name = draft_name
self._on_ack_callbacks = []
def available(self, capabilities: typing.List[str]) -> str:
match = list(set(capabilities)&self._caps)
return match[0] if match else None
def enabled(self, capability: str) -> bool:
return capability in self._caps
def on_ack(self, callback: typing.Callable[[], None]):
self._on_ack_callbacks.append(callback)
def ack(self):
for callback in self._on_ack_callbacks:
callback()
def nak(self):
pass