split out command_spec module

This commit is contained in:
jesopo 2020-02-14 21:57:06 +00:00
parent db5787a594
commit f827bdce7f
3 changed files with 113 additions and 73 deletions

View file

@ -1,4 +1,5 @@
from src import EventManager, ModuleManager, utils from src import EventManager, ModuleManager, utils
from . import types
# describing command arg specifications, to centralise parsing and validating. # describing command arg specifications, to centralise parsing and validating.
# #
@ -43,72 +44,12 @@ class Module(ModuleManager.BaseModule):
value = simple_value value = simple_value
n = simple_count n = simple_count
error = argument_type.error() error = argument_type.error()
elif argument_type.type == "rchannel": elif argument_type.type in types.TYPES:
if channel: func = types.TYPES[argument_type.type]
value = channel try:
elif args: value, n = func(server, channel, user, args)
n = 1 except types.SpecTypeError as e:
if args[0] in server.channels: error = e.message
value = server.channels.get(args[0])
error = "No such channel"
else:
error = "No channel provided"
elif argument_type.type == "channel" and args:
if args[0] in server.channels:
value = server.channels.get(args[0])
n = 1
error = "No such channel"
elif argument_type.type == "cuser" and args:
tuser = server.get_user(args[0], create=False)
if tuser and channel.has_user(tuser):
value = tuser
n = 1
error = "That user is not in this channel"
elif argument_type.type == "ruser":
if args:
value = server.get_user(args[0], create=False)
n = 1
else:
value = user
error = "No such user"
elif argument_type.type == "user":
if args:
value = server.get_user(args[0], create=False)
n = 1
error = "No such user"
else:
error = "No user provided"
elif argument_type.type == "ouser":
if args:
if server.has_user_id(args[0]):
value = server.get_user(args[0], create=True)
error = "Unknown nickname"
n = 1
elif argument_type.type == "nuser":
if args:
value = server.get_user(args[0], create=True)
n = 1
elif argument_type.type == "lstring":
if args:
value = " ".join(args)
n = len(args)
else:
last_message = (channel or user).buffer.get()
if last_message:
value = last_message.message
n = 0
else:
n = 1
elif argument_type.type == "channelonly":
if channel:
value = True
n = 0
error = "Command not valid in PM"
elif argument_type.type == "privateonly":
if not channel:
value = True
n = 0
error = "Command not valid in-channel"
options.append([argument_type, value, n, error]) options.append([argument_type, value, n, error])
return options return options

View file

@ -0,0 +1,99 @@
class SpecTypeError(Exception):
def __init__(self, message: str, arg_count: int=1):
self.message = message
self.arg_count = arg_count
TYPES = {}
def _type(func):
TYPES[func.__name__] = func
def _assert_args(args, type):
if not args:
raise SpecTypeError("No %s provided" % type)
@_type
def rchannel(server, channel, user, args):
if channel:
return channel, 0
elif args:
if args[0] in server.channels:
return server.channels.get(args[0]), 1
else:
raise SpecTypeError("No such channel")
else:
raise SpecTypeError("No channel provided")
@_type
def channel(server, channel, user, args):
_assert_args(args, "channel")
if args[0] in server.channels:
return server.channels.get(args[0]), 1
else:
raise SpecTypeError("No such channel")
@_type
def cuser(server, channel, user, args):
_assert_args(args, "user")
target_user = server.get_user(args[0], create=False)
if target_user and channel.has_user(target_user):
return target_user, 1
else:
raise SpecTypeError("That user is not in this channel")
@_type
def ruser(server, channel, user, args):
if args:
target_user = server.get_user(args[0], create=False)
if target_user:
return target_user, 1
else:
raise SpecTypeError("No such user")
else:
return user, 0
@_type
def user(server, channel, user, args):
_assert_args(args, "user")
target_user = server.get_user(args[0], create=False)
if target_user:
return target_user, 1
else:
raise SpecTypeError("No such user")
@_type
def ouser(server, channel, user, args):
_assert_args(args, "user")
if server.has_user_id(args[0]):
return server.get_user(args[0], create=True), 1
else:
raise SpecTypeError("No such user")
@_type
def nuser(server, channel, user, args):
_assert_args(args, "user")
return server.get_user(args[0], create=True), 1
@_type
def lstring(server, channel, user, args):
if args:
return " ".join(args), len(args)
else:
last_message = (channel or user).buffer.get()
if last_message:
return last_message.message, 0
else:
raise SpecTypeError("No message found")
@_type
def channelonly(server, channel, user, args):
if channel:
return True, 0
else:
raise SpecTypeError("Command not valid in PM")
@_type
def privateonly(server, channel, user, args):
if not channel:
return True, 0
else:
raise SpecTypeError("Command not valid in channel")

View file

@ -25,17 +25,17 @@ class SpecArgumentType(object):
return None return None
class SpecArgumentTypeWord(SpecArgumentType): class SpecArgumentTypeWord(SpecArgumentType):
def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: def simple(self, args):
if args: if args:
return args[0], 1 return args[0], 1
return None, 1 return None, 1
class SpecArgumentTypeAdditionalWord(SpecArgumentType): class SpecArgumentTypeAdditionalWord(SpecArgumentType):
def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: def simple(self, args):
if len(args) > 1: if len(args) > 1:
return args[0], 1 return args[0], 1
return None, 1 return None, 1
class SpecArgumentTypeWordLower(SpecArgumentTypeWord): class SpecArgumentTypeWordLower(SpecArgumentTypeWord):
def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: def simple(self, args):
out = SpecArgumentTypeWord.simple(self, args) out = SpecArgumentTypeWord.simple(self, args)
if out[0]: if out[0]:
return out[0].lower(), out[1] return out[0].lower(), out[1]
@ -44,15 +44,15 @@ class SpecArgumentTypeWordLower(SpecArgumentTypeWord):
class SpecArgumentTypeString(SpecArgumentType): class SpecArgumentTypeString(SpecArgumentType):
def name(self): def name(self):
return "%s ..." % SpecArgumentType.name(self) return "%s ..." % SpecArgumentType.name(self)
def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: def simple(self, args):
if args: if args:
return " ".join(args), len(args) return " ".join(args), len(args)
return None, 1 return None, 1
class SpecArgumentTypeTrimString(SpecArgumentTypeString): class SpecArgumentTypeTrimString(SpecArgumentTypeString):
def simple(self, args: typing.List[str]): def simple(self, args):
return SpecArgumentTypeString.simple(self, list(filter(None, args))) return SpecArgumentTypeString.simple(self, list(filter(None, args)))
class SpecArgumentTypeWords(SpecArgumentTypeString): class SpecArgumentTypeWords(SpecArgumentTypeString):
def simple(self, args: typing.List[str]): def simple(self, args):
if args: if args:
out = list(filter(None, args)) out = list(filter(None, args))
return out, len(out) return out, len(out)
@ -67,7 +67,7 @@ class SpecArgumentTypeInt(SpecArgumentType):
class SpecArgumentTypeDuration(SpecArgumentType): class SpecArgumentTypeDuration(SpecArgumentType):
def name(self): def name(self):
return "+%s" % (SpecArgumentType.name(self) or "duration") return "+%s" % (SpecArgumentType.name(self) or "duration")
def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: def simple(self, args):
if args: if args:
return duration(args[0]), 1 return duration(args[0]), 1
return None, 1 return None, 1