Revamp how CAPs are tracked through REQ and ACK/NAK etc
This commit is contained in:
parent
6ef7f8374d
commit
d291cd5063
5 changed files with 68 additions and 55 deletions
|
@ -19,14 +19,20 @@ CAPABILITIES = [
|
|||
utils.irc.Capability(None, "draft/setname")
|
||||
]
|
||||
|
||||
def _match_caps(capabilities):
|
||||
matched = []
|
||||
for capability in CAPABILITIES:
|
||||
available = capability.available(capabilities)
|
||||
def _match_caps(our_capabilities, offered_capabilities):
|
||||
matched = {}
|
||||
for capability in our_capabilities:
|
||||
available = capability.available(offered_capabilities)
|
||||
if available:
|
||||
matched.append(available)
|
||||
matched[available] = capability
|
||||
return matched
|
||||
|
||||
def _caps_offered(server, caps):
|
||||
blacklist = server.get_setting("blacklisted-caps", [])
|
||||
for cap_name, cap in caps.items():
|
||||
if not cap_name in blacklist:
|
||||
server.capability_queue[cap_name] = cap
|
||||
|
||||
def cap(events, event):
|
||||
capabilities = utils.parse.keyvalue(event["args"][-1])
|
||||
subcommand = event["args"][1].lower()
|
||||
|
@ -36,20 +42,18 @@ def cap(events, event):
|
|||
event["server"].cap_started = True
|
||||
event["server"].server_capabilities.update(capabilities)
|
||||
if not is_multiline:
|
||||
matched_caps = _match_caps(
|
||||
list(event["server"].server_capabilities.keys()))
|
||||
blacklisted_caps = event["server"].get_setting(
|
||||
"blacklisted-caps", [])
|
||||
matched_caps = list(
|
||||
set(matched_caps)-set(blacklisted_caps))
|
||||
server_caps = list(event["server"].server_capabilities.keys())
|
||||
matched_caps = _match_caps(CAPABILITIES, server_caps)
|
||||
|
||||
event["server"].queue_capabilities(matched_caps)
|
||||
|
||||
events.on("received.cap.ls").call(
|
||||
module_caps = events.on("received.cap.ls").call(
|
||||
capabilities=event["server"].server_capabilities,
|
||||
server=event["server"])
|
||||
module_caps = list(filter(None, module_caps))
|
||||
matched_caps.update(_match_caps(module_caps, server_caps))
|
||||
|
||||
if event["server"].has_capability_queue():
|
||||
_caps_offered(event["server"], matched_caps)
|
||||
|
||||
if event["server"].capability_queue:
|
||||
event["server"].send_capability_queue()
|
||||
else:
|
||||
event["server"].send_capability_end()
|
||||
|
@ -58,12 +62,15 @@ def cap(events, event):
|
|||
event["server"].server_capabilities.update(capabilities)
|
||||
|
||||
matched_caps = _match_caps(list(capabilities_keys))
|
||||
event["server"].queue_capabilities(matched_caps)
|
||||
|
||||
events.on("received.cap.new").call(server=event["server"],
|
||||
capabilities=capabilities)
|
||||
module_caps = events.on("received.cap.new").call(
|
||||
server=event["server"], capabilities=capabilities)
|
||||
module_caps = list(filter(None, module_caps))
|
||||
matched_caps.update(_match_caps(module_caps, capabilities_keys))
|
||||
|
||||
if event["server"].has_capability_queue():
|
||||
_caps_offered(event["server"], matched_caps)
|
||||
|
||||
if event["server"].capability_queue:
|
||||
event["server"].send_capability_queue()
|
||||
elif subcommand == "del":
|
||||
for capability in capabilities.keys():
|
||||
|
@ -78,11 +85,17 @@ def cap(events, event):
|
|||
server=event["server"])
|
||||
|
||||
if subcommand == "ack" or subcommand == "nak":
|
||||
ack = subcommand == "ack"
|
||||
for capability in capabilities:
|
||||
event["server"].requested_capabilities.remove(capability)
|
||||
cap_obj = event["server"].capability_queue[capability]
|
||||
del event["server"].capability_queue[capability]
|
||||
if ack:
|
||||
cap_obj.ack()
|
||||
else:
|
||||
cap_obj.nak()
|
||||
|
||||
if (event["server"].cap_started and
|
||||
not event["server"].requested_capabilities and
|
||||
not event["server"].capability_queue and
|
||||
not event["server"].waiting_for_capabilities()):
|
||||
event["server"].cap_started = False
|
||||
event["server"].send_capability_end()
|
||||
|
|
|
@ -23,12 +23,12 @@ class Module(ModuleManager.BaseModule):
|
|||
@utils.hook("received.cap.ls")
|
||||
def on_cap_ls(self, event):
|
||||
if CAP in event["capabilities"]:
|
||||
event["server"].queue_capability(CAP)
|
||||
cap = utils.irc.Capability(CAP)
|
||||
cap.on_ack(lambda: self._cap_ack(event["server"]))
|
||||
return cap
|
||||
|
||||
@utils.hook("received.cap.ack")
|
||||
def on_cap_ack(self, event):
|
||||
if CAP in event["capabilities"]:
|
||||
event["server"].wait_for_capability("resume")
|
||||
def _cap_ack(self, server):
|
||||
server.wait_for_capability("resume")
|
||||
|
||||
@utils.hook("received.resume")
|
||||
def on_resume(self, event):
|
||||
|
|
|
@ -41,22 +41,22 @@ class Module(ModuleManager.BaseModule):
|
|||
do_sasl = True
|
||||
|
||||
if do_sasl:
|
||||
event["server"].queue_capability("sasl")
|
||||
cap = utils.irc.Capability("sasl")
|
||||
cap.on_ack(lambda: self._sasl_ack(event["server"]))
|
||||
return cap
|
||||
|
||||
@utils.hook("received.cap.ack")
|
||||
def on_cap_ack(self, event):
|
||||
if "sasl" in event["capabilities"]:
|
||||
sasl = event["server"].get_setting("sasl")
|
||||
mechanism = sasl["mechanism"].upper()
|
||||
if mechanism == "USERPASS":
|
||||
server_mechanisms = event["server"].server_capabilities["sasl"]
|
||||
server_mechanisms = server_mechanisms or [
|
||||
USERPASS_MECHANISMS[0]]
|
||||
mechanism = self._best_userpass_mechanism(server_mechanisms)
|
||||
def _sasl_ack(self, server):
|
||||
sasl = server.get_setting("sasl")
|
||||
mechanism = sasl["mechanism"].upper()
|
||||
if mechanism == "USERPASS":
|
||||
server_mechanisms = server.server_capabilities["sasl"]
|
||||
server_mechanisms = server_mechanisms or [
|
||||
USERPASS_MECHANISMS[0]]
|
||||
mechanism = self._best_userpass_mechanism(server_mechanisms)
|
||||
|
||||
event["server"].send_authenticate(mechanism)
|
||||
event["server"].sasl_mechanism = mechanism
|
||||
event["server"].wait_for_capability("sasl")
|
||||
server.send_authenticate(mechanism)
|
||||
server.sasl_mechanism = mechanism
|
||||
server.wait_for_capability("sasl")
|
||||
|
||||
@utils.hook("received.authenticate")
|
||||
def on_authenticate(self, event):
|
||||
|
|
|
@ -24,10 +24,10 @@ class Server(IRCObject.Object):
|
|||
self.realname = None # type: typing.Optional[str]
|
||||
self.hostname = None # type: typing.Optional[str]
|
||||
|
||||
self._capability_queue = set([]) # type: typing.Set[str]
|
||||
self.capability_queue = {
|
||||
} # type: typing.Dict[str, utils.irc.Capability]
|
||||
self._capabilities_waiting = set([]) # type: typing.Set[str]
|
||||
self.agreed_capabilities = set([]) # type: typing.Set[str]
|
||||
self.requested_capabilities = [] # type: typing.List[str]
|
||||
self.server_capabilities = {} # type: typing.Dict[str, str]
|
||||
self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine]
|
||||
self.cap_started = False
|
||||
|
@ -269,21 +269,12 @@ class Server(IRCObject.Object):
|
|||
|
||||
def send_capibility_ls(self) -> IRCLine.SentLine:
|
||||
return self.send(utils.irc.protocol.capability_ls())
|
||||
def queue_capability(self, capability: str):
|
||||
self._capability_queue.add(capability)
|
||||
def queue_capabilities(self, capabilities: typing.List[str]):
|
||||
self._capability_queue.update(capabilities)
|
||||
def send_capability_queue(self):
|
||||
if self.has_capability_queue():
|
||||
capability_queue = list(self._capability_queue)
|
||||
self._capability_queue.clear()
|
||||
capability_queue = [cap for cap in self.capability_queue.keys()]
|
||||
|
||||
for i in range(0, len(capability_queue), 10):
|
||||
capability_batch = capability_queue[i:i+10]
|
||||
self.requested_capabilities += capability_batch
|
||||
self.send_capability_request(" ".join(capability_batch))
|
||||
def has_capability_queue(self):
|
||||
return bool(len(self._capability_queue))
|
||||
for i in range(0, len(capability_queue), 10):
|
||||
capability_batch = capability_queue[i:i+10]
|
||||
self.send_capability_request(" ".join(capability_batch))
|
||||
def send_capability_request(self, capability: str) -> IRCLine.SentLine:
|
||||
return self.send(utils.irc.protocol.capability_request(capability))
|
||||
def send_capability_end(self) -> IRCLine.SentLine:
|
||||
|
|
|
@ -281,8 +281,17 @@ class Capability(object):
|
|||
self._caps = set([name, draft_name])
|
||||
self._name = name
|
||||
self._draft_name = draft_name
|
||||
self._on_ack_callbacks = []
|
||||
def available(self, capabilities: typing.List[str]) -> str:
|
||||
match = list(set(capabilities)&self._caps)
|
||||
return match[0] if match else None
|
||||
def enabled(self, capability: str) -> bool:
|
||||
return capability in self._caps
|
||||
|
||||
def on_ack(self, callback: typing.Callable[[], None]):
|
||||
self._on_ack_callbacks.append(callback)
|
||||
def ack(self):
|
||||
for callback in self._on_ack_callbacks:
|
||||
callback()
|
||||
def nak(self):
|
||||
pass
|
||||
|
|
Loading…
Reference in a new issue