Revamp how CAPs are tracked through REQ and ACK/NAK etc
This commit is contained in:
parent
6ef7f8374d
commit
d291cd5063
5 changed files with 68 additions and 55 deletions
|
@ -19,14 +19,20 @@ CAPABILITIES = [
|
||||||
utils.irc.Capability(None, "draft/setname")
|
utils.irc.Capability(None, "draft/setname")
|
||||||
]
|
]
|
||||||
|
|
||||||
def _match_caps(capabilities):
|
def _match_caps(our_capabilities, offered_capabilities):
|
||||||
matched = []
|
matched = {}
|
||||||
for capability in CAPABILITIES:
|
for capability in our_capabilities:
|
||||||
available = capability.available(capabilities)
|
available = capability.available(offered_capabilities)
|
||||||
if available:
|
if available:
|
||||||
matched.append(available)
|
matched[available] = capability
|
||||||
return matched
|
return matched
|
||||||
|
|
||||||
|
def _caps_offered(server, caps):
|
||||||
|
blacklist = server.get_setting("blacklisted-caps", [])
|
||||||
|
for cap_name, cap in caps.items():
|
||||||
|
if not cap_name in blacklist:
|
||||||
|
server.capability_queue[cap_name] = cap
|
||||||
|
|
||||||
def cap(events, event):
|
def cap(events, event):
|
||||||
capabilities = utils.parse.keyvalue(event["args"][-1])
|
capabilities = utils.parse.keyvalue(event["args"][-1])
|
||||||
subcommand = event["args"][1].lower()
|
subcommand = event["args"][1].lower()
|
||||||
|
@ -36,20 +42,18 @@ def cap(events, event):
|
||||||
event["server"].cap_started = True
|
event["server"].cap_started = True
|
||||||
event["server"].server_capabilities.update(capabilities)
|
event["server"].server_capabilities.update(capabilities)
|
||||||
if not is_multiline:
|
if not is_multiline:
|
||||||
matched_caps = _match_caps(
|
server_caps = list(event["server"].server_capabilities.keys())
|
||||||
list(event["server"].server_capabilities.keys()))
|
matched_caps = _match_caps(CAPABILITIES, server_caps)
|
||||||
blacklisted_caps = event["server"].get_setting(
|
|
||||||
"blacklisted-caps", [])
|
|
||||||
matched_caps = list(
|
|
||||||
set(matched_caps)-set(blacklisted_caps))
|
|
||||||
|
|
||||||
event["server"].queue_capabilities(matched_caps)
|
module_caps = events.on("received.cap.ls").call(
|
||||||
|
|
||||||
events.on("received.cap.ls").call(
|
|
||||||
capabilities=event["server"].server_capabilities,
|
capabilities=event["server"].server_capabilities,
|
||||||
server=event["server"])
|
server=event["server"])
|
||||||
|
module_caps = list(filter(None, module_caps))
|
||||||
|
matched_caps.update(_match_caps(module_caps, server_caps))
|
||||||
|
|
||||||
if event["server"].has_capability_queue():
|
_caps_offered(event["server"], matched_caps)
|
||||||
|
|
||||||
|
if event["server"].capability_queue:
|
||||||
event["server"].send_capability_queue()
|
event["server"].send_capability_queue()
|
||||||
else:
|
else:
|
||||||
event["server"].send_capability_end()
|
event["server"].send_capability_end()
|
||||||
|
@ -58,12 +62,15 @@ def cap(events, event):
|
||||||
event["server"].server_capabilities.update(capabilities)
|
event["server"].server_capabilities.update(capabilities)
|
||||||
|
|
||||||
matched_caps = _match_caps(list(capabilities_keys))
|
matched_caps = _match_caps(list(capabilities_keys))
|
||||||
event["server"].queue_capabilities(matched_caps)
|
|
||||||
|
|
||||||
events.on("received.cap.new").call(server=event["server"],
|
module_caps = events.on("received.cap.new").call(
|
||||||
capabilities=capabilities)
|
server=event["server"], capabilities=capabilities)
|
||||||
|
module_caps = list(filter(None, module_caps))
|
||||||
|
matched_caps.update(_match_caps(module_caps, capabilities_keys))
|
||||||
|
|
||||||
if event["server"].has_capability_queue():
|
_caps_offered(event["server"], matched_caps)
|
||||||
|
|
||||||
|
if event["server"].capability_queue:
|
||||||
event["server"].send_capability_queue()
|
event["server"].send_capability_queue()
|
||||||
elif subcommand == "del":
|
elif subcommand == "del":
|
||||||
for capability in capabilities.keys():
|
for capability in capabilities.keys():
|
||||||
|
@ -78,11 +85,17 @@ def cap(events, event):
|
||||||
server=event["server"])
|
server=event["server"])
|
||||||
|
|
||||||
if subcommand == "ack" or subcommand == "nak":
|
if subcommand == "ack" or subcommand == "nak":
|
||||||
|
ack = subcommand == "ack"
|
||||||
for capability in capabilities:
|
for capability in capabilities:
|
||||||
event["server"].requested_capabilities.remove(capability)
|
cap_obj = event["server"].capability_queue[capability]
|
||||||
|
del event["server"].capability_queue[capability]
|
||||||
|
if ack:
|
||||||
|
cap_obj.ack()
|
||||||
|
else:
|
||||||
|
cap_obj.nak()
|
||||||
|
|
||||||
if (event["server"].cap_started and
|
if (event["server"].cap_started and
|
||||||
not event["server"].requested_capabilities and
|
not event["server"].capability_queue and
|
||||||
not event["server"].waiting_for_capabilities()):
|
not event["server"].waiting_for_capabilities()):
|
||||||
event["server"].cap_started = False
|
event["server"].cap_started = False
|
||||||
event["server"].send_capability_end()
|
event["server"].send_capability_end()
|
||||||
|
|
|
@ -23,12 +23,12 @@ class Module(ModuleManager.BaseModule):
|
||||||
@utils.hook("received.cap.ls")
|
@utils.hook("received.cap.ls")
|
||||||
def on_cap_ls(self, event):
|
def on_cap_ls(self, event):
|
||||||
if CAP in event["capabilities"]:
|
if CAP in event["capabilities"]:
|
||||||
event["server"].queue_capability(CAP)
|
cap = utils.irc.Capability(CAP)
|
||||||
|
cap.on_ack(lambda: self._cap_ack(event["server"]))
|
||||||
|
return cap
|
||||||
|
|
||||||
@utils.hook("received.cap.ack")
|
def _cap_ack(self, server):
|
||||||
def on_cap_ack(self, event):
|
server.wait_for_capability("resume")
|
||||||
if CAP in event["capabilities"]:
|
|
||||||
event["server"].wait_for_capability("resume")
|
|
||||||
|
|
||||||
@utils.hook("received.resume")
|
@utils.hook("received.resume")
|
||||||
def on_resume(self, event):
|
def on_resume(self, event):
|
||||||
|
|
|
@ -41,22 +41,22 @@ class Module(ModuleManager.BaseModule):
|
||||||
do_sasl = True
|
do_sasl = True
|
||||||
|
|
||||||
if do_sasl:
|
if do_sasl:
|
||||||
event["server"].queue_capability("sasl")
|
cap = utils.irc.Capability("sasl")
|
||||||
|
cap.on_ack(lambda: self._sasl_ack(event["server"]))
|
||||||
|
return cap
|
||||||
|
|
||||||
@utils.hook("received.cap.ack")
|
def _sasl_ack(self, server):
|
||||||
def on_cap_ack(self, event):
|
sasl = server.get_setting("sasl")
|
||||||
if "sasl" in event["capabilities"]:
|
|
||||||
sasl = event["server"].get_setting("sasl")
|
|
||||||
mechanism = sasl["mechanism"].upper()
|
mechanism = sasl["mechanism"].upper()
|
||||||
if mechanism == "USERPASS":
|
if mechanism == "USERPASS":
|
||||||
server_mechanisms = event["server"].server_capabilities["sasl"]
|
server_mechanisms = server.server_capabilities["sasl"]
|
||||||
server_mechanisms = server_mechanisms or [
|
server_mechanisms = server_mechanisms or [
|
||||||
USERPASS_MECHANISMS[0]]
|
USERPASS_MECHANISMS[0]]
|
||||||
mechanism = self._best_userpass_mechanism(server_mechanisms)
|
mechanism = self._best_userpass_mechanism(server_mechanisms)
|
||||||
|
|
||||||
event["server"].send_authenticate(mechanism)
|
server.send_authenticate(mechanism)
|
||||||
event["server"].sasl_mechanism = mechanism
|
server.sasl_mechanism = mechanism
|
||||||
event["server"].wait_for_capability("sasl")
|
server.wait_for_capability("sasl")
|
||||||
|
|
||||||
@utils.hook("received.authenticate")
|
@utils.hook("received.authenticate")
|
||||||
def on_authenticate(self, event):
|
def on_authenticate(self, event):
|
||||||
|
|
|
@ -24,10 +24,10 @@ class Server(IRCObject.Object):
|
||||||
self.realname = None # type: typing.Optional[str]
|
self.realname = None # type: typing.Optional[str]
|
||||||
self.hostname = None # type: typing.Optional[str]
|
self.hostname = None # type: typing.Optional[str]
|
||||||
|
|
||||||
self._capability_queue = set([]) # type: typing.Set[str]
|
self.capability_queue = {
|
||||||
|
} # type: typing.Dict[str, utils.irc.Capability]
|
||||||
self._capabilities_waiting = set([]) # type: typing.Set[str]
|
self._capabilities_waiting = set([]) # type: typing.Set[str]
|
||||||
self.agreed_capabilities = set([]) # type: typing.Set[str]
|
self.agreed_capabilities = set([]) # type: typing.Set[str]
|
||||||
self.requested_capabilities = [] # type: typing.List[str]
|
|
||||||
self.server_capabilities = {} # type: typing.Dict[str, str]
|
self.server_capabilities = {} # type: typing.Dict[str, str]
|
||||||
self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine]
|
self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine]
|
||||||
self.cap_started = False
|
self.cap_started = False
|
||||||
|
@ -269,21 +269,12 @@ class Server(IRCObject.Object):
|
||||||
|
|
||||||
def send_capibility_ls(self) -> IRCLine.SentLine:
|
def send_capibility_ls(self) -> IRCLine.SentLine:
|
||||||
return self.send(utils.irc.protocol.capability_ls())
|
return self.send(utils.irc.protocol.capability_ls())
|
||||||
def queue_capability(self, capability: str):
|
|
||||||
self._capability_queue.add(capability)
|
|
||||||
def queue_capabilities(self, capabilities: typing.List[str]):
|
|
||||||
self._capability_queue.update(capabilities)
|
|
||||||
def send_capability_queue(self):
|
def send_capability_queue(self):
|
||||||
if self.has_capability_queue():
|
capability_queue = [cap for cap in self.capability_queue.keys()]
|
||||||
capability_queue = list(self._capability_queue)
|
|
||||||
self._capability_queue.clear()
|
|
||||||
|
|
||||||
for i in range(0, len(capability_queue), 10):
|
for i in range(0, len(capability_queue), 10):
|
||||||
capability_batch = capability_queue[i:i+10]
|
capability_batch = capability_queue[i:i+10]
|
||||||
self.requested_capabilities += capability_batch
|
|
||||||
self.send_capability_request(" ".join(capability_batch))
|
self.send_capability_request(" ".join(capability_batch))
|
||||||
def has_capability_queue(self):
|
|
||||||
return bool(len(self._capability_queue))
|
|
||||||
def send_capability_request(self, capability: str) -> IRCLine.SentLine:
|
def send_capability_request(self, capability: str) -> IRCLine.SentLine:
|
||||||
return self.send(utils.irc.protocol.capability_request(capability))
|
return self.send(utils.irc.protocol.capability_request(capability))
|
||||||
def send_capability_end(self) -> IRCLine.SentLine:
|
def send_capability_end(self) -> IRCLine.SentLine:
|
||||||
|
|
|
@ -281,8 +281,17 @@ class Capability(object):
|
||||||
self._caps = set([name, draft_name])
|
self._caps = set([name, draft_name])
|
||||||
self._name = name
|
self._name = name
|
||||||
self._draft_name = draft_name
|
self._draft_name = draft_name
|
||||||
|
self._on_ack_callbacks = []
|
||||||
def available(self, capabilities: typing.List[str]) -> str:
|
def available(self, capabilities: typing.List[str]) -> str:
|
||||||
match = list(set(capabilities)&self._caps)
|
match = list(set(capabilities)&self._caps)
|
||||||
return match[0] if match else None
|
return match[0] if match else None
|
||||||
def enabled(self, capability: str) -> bool:
|
def enabled(self, capability: str) -> bool:
|
||||||
return capability in self._caps
|
return capability in self._caps
|
||||||
|
|
||||||
|
def on_ack(self, callback: typing.Callable[[], None]):
|
||||||
|
self._on_ack_callbacks.append(callback)
|
||||||
|
def ack(self):
|
||||||
|
for callback in self._on_ack_callbacks:
|
||||||
|
callback()
|
||||||
|
def nak(self):
|
||||||
|
pass
|
||||||
|
|
Loading…
Reference in a new issue