Catch yield
s in command callbacks for e.g. permission checks
This commit is contained in:
parent
6057d86edc
commit
d7fa2cfa24
3 changed files with 87 additions and 27 deletions
|
@ -13,9 +13,7 @@ LOWHIGH = {
|
|||
"help": "Set which channel mode is considered to be 'high' access",
|
||||
"example": "o"})
|
||||
class Module(ModuleManager.BaseModule):
|
||||
@utils.hook("preprocess.command")
|
||||
def preprocess_command(self, event):
|
||||
require_mode = event["hook"].get_kwarg("require_mode")
|
||||
def _check_command(self, event, require_mode):
|
||||
if event["is_channel"] and require_mode:
|
||||
if require_mode.lower() in LOWHIGH:
|
||||
require_mode = event["target"].get_setting(
|
||||
|
@ -27,3 +25,13 @@ class Module(ModuleManager.BaseModule):
|
|||
return "You do not have permission to do this"
|
||||
else:
|
||||
return utils.consts.PERMISSION_FORCE_SUCCESS
|
||||
|
||||
@utils.hook("preprocess.command")
|
||||
def preprocess_command(self, event):
|
||||
require_mode = event["hook"].get_kwarg("require_mode")
|
||||
if not require_mode == None:
|
||||
return self._check_command(event, require_mode)
|
||||
|
||||
@utils.hook("check.command.channel-mode")
|
||||
def check_command(self, event):
|
||||
return self._check_command(event, event["check_args"][0])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#--depends-on config
|
||||
#--depends-on permissions
|
||||
|
||||
import re, string
|
||||
import re, string, types
|
||||
from src import EventManager, ModuleManager, utils
|
||||
from . import outs
|
||||
|
||||
|
@ -130,6 +130,60 @@ class Module(ModuleManager.BaseModule):
|
|||
|
||||
return hook, args_split
|
||||
|
||||
def _check(self, context, kwargs, request=None, **extra_kwargs):
|
||||
event_hook = self.events.on(context).on("command")
|
||||
if not request == None:
|
||||
event_hook = event_hook.on(request)
|
||||
|
||||
returns = event_hook.call_unsafe(**kwargs, **extra_kwargs)
|
||||
|
||||
hard_fail = False
|
||||
force_success = False
|
||||
error = None
|
||||
for returned in returns:
|
||||
if returned == utils.consts.PERMISSION_HARD_FAIL:
|
||||
hard_fail = True
|
||||
break
|
||||
elif returned == utils.consts.PERMISSION_FORCE_SUCCESS:
|
||||
force_success = True
|
||||
elif returned:
|
||||
error = returned
|
||||
|
||||
if hard_fail or (not force_success and error):
|
||||
if error:
|
||||
return False, error
|
||||
return False, None
|
||||
return True, None
|
||||
|
||||
def _hook_call_catch(self, func):
|
||||
try:
|
||||
return True, func()
|
||||
except utils.EventError as e:
|
||||
return False, str(e)
|
||||
def _hook_call(self, hook, event, check_kwargs):
|
||||
hook_success, hook_return = self._hook_call_catch(
|
||||
lambda: hook.call(event))
|
||||
|
||||
if hook_success and hook_return and isinstance(
|
||||
hook_return, types.GeneratorType):
|
||||
while True:
|
||||
try:
|
||||
next_success, next_return = self._hook_call_catch(
|
||||
lambda: next(hook_return))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
if next_success:
|
||||
if isinstance(next_return, utils.Check):
|
||||
check_success, check_message = self._check("check",
|
||||
check_kwargs, next_return.request,
|
||||
check_args=next_return.args)
|
||||
if not check_success:
|
||||
return False, check_message
|
||||
else:
|
||||
break
|
||||
return hook_success, hook_return
|
||||
|
||||
def command(self, server, target, target_str, is_channel, user, command,
|
||||
args_split, tags, hook, **kwargs):
|
||||
if self._is_ignored(server, user, command):
|
||||
|
@ -175,26 +229,16 @@ class Module(ModuleManager.BaseModule):
|
|||
stderr.write("Not enough arguments (minimum: %d)" %
|
||||
min_args).send(command_method)
|
||||
else:
|
||||
returns = self.events.on("preprocess.command").call_unsafe(
|
||||
hook=hook, user=user, server=server, target=target,
|
||||
is_channel=is_channel, tags=tags, args_split=args_split,
|
||||
command=command, **kwargs)
|
||||
check_kwargs = {"hook": hook, "user": user, "server": server,
|
||||
"target": target, "is_channel": is_channel, "tags": tags,
|
||||
"args_split": args_split, "command": command}
|
||||
check_kwargs.update(kwargs)
|
||||
|
||||
hard_fail = False
|
||||
force_success = False
|
||||
error = None
|
||||
for returned in returns:
|
||||
if returned == utils.consts.PERMISSION_HARD_FAIL:
|
||||
hard_fail = True
|
||||
break
|
||||
elif returned == utils.consts.PERMISSION_FORCE_SUCCESS:
|
||||
force_success = True
|
||||
elif returned:
|
||||
error = returned
|
||||
|
||||
if hard_fail or (not force_success and error):
|
||||
if error:
|
||||
stderr.write(error).send(command_method)
|
||||
check_success, check_message = self._check("preprocess",
|
||||
check_kwargs)
|
||||
if not check_success:
|
||||
if check_message:
|
||||
stderr.write(check_message).send(command_method)
|
||||
return True
|
||||
|
||||
args = " ".join(args_split)
|
||||
|
@ -206,10 +250,13 @@ class Module(ModuleManager.BaseModule):
|
|||
|
||||
self.log.trace("calling command '%s': %s",
|
||||
[command, new_event.kwargs])
|
||||
try:
|
||||
hook.call(new_event)
|
||||
except utils.EventError as e:
|
||||
stderr.write(str(e))
|
||||
hook_success, hook_message = self._hook_call(hook, new_event,
|
||||
check_kwargs)
|
||||
|
||||
if not hook_success:
|
||||
if not hook_message == None:
|
||||
stderr.write(hook_message).send(command_method)
|
||||
return True
|
||||
|
||||
if not hook.kwargs.get("skip_out", False):
|
||||
had_out = stdout.has_text() or stderr.has_text()
|
||||
|
|
|
@ -185,6 +185,11 @@ def export(setting: str, value: typing.Any):
|
|||
return module
|
||||
return _export_func
|
||||
|
||||
class Check(object):
|
||||
def __init__(self, request: str, *args: typing.List[str]):
|
||||
self.request = request
|
||||
self.args = args
|
||||
|
||||
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
|
||||
def top_10(items: typing.Dict[typing.Any, typing.Any],
|
||||
convert_key: TOP_10_CALLABLE=lambda x: x,
|
||||
|
|
Loading…
Reference in a new issue