diff --git a/modules/dice.py b/modules/dice.py index 8e34c9e0..8a787d94 100644 --- a/modules/dice.py +++ b/modules/dice.py @@ -15,43 +15,37 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.command.roll") @utils.hook("received.command.dice", alias_of="roll") @utils.kwarg("help", "Roll dice DND-style") - @utils.kwarg("usage", "[1-%s]d[1-%d]" % (MAX_DICE, MAX_SIDES)) + @utils.spec("?<1d20>pattern(^(\d+)d(\d+)((?:\s*[-+]\d+)*))") def roll_dice(self, event): - args = None - if event["args_split"]: - args = event["args"] - else: - args = "1d6" + dice_count, side_count = 1, 6 + roll = "1d6" + modifiers = [] - match = RE_DICE.match(args) - if match: - roll = match.group(0) - dice_count = int(match.group(1) or "1") - side_count = int(match.group(2)) - modifiers_str = "".join(match.group(3).split()) - modifiers = RE_MODIFIERS.findall(modifiers_str) + if event["spec"][0]: + dice_count = int(event["spec"][0].group(1) or "1") + side_count = int(event["spec"][0].group(2)) + roll = event["spec"][0].group(0) + modifiers = RE_MODIFIERS.findall(event["spec"][0].group(3)) - if dice_count > 6: - raise utils.EventError("Max number of dice is %s" % MAX_DICE) - if side_count > MAX_SIDES: - raise utils.EventError("Max number of sides is %s" - % MAX_SIDES) + if dice_count > 6: + raise utils.EventError("Max number of dice is %s" % MAX_DICE) + if side_count > MAX_SIDES: + raise utils.EventError("Max number of sides is %s" + % MAX_SIDES) - results = random.choices(range(1, side_count+1), k=dice_count) + results = random.choices(range(1, side_count+1), k=dice_count) - total_n = sum(results) - for modifier in modifiers: - if modifier[0] == "+": - total_n += int(modifier[1:]) - else: - total_n -= int(modifier[1:]) + total_n = sum(results) + for modifier in modifiers: + if modifier[0] == "+": + total_n += int(modifier[1:]) + else: + total_n -= int(modifier[1:]) - total = "" - if len(results) > 1 or modifiers: - total = " (total: %d)" % total_n + total = "" + if len(results) > 1 or modifiers: + total = " (total: %d)" % total_n - results_str = ", ".join(str(r) for r in results) - event["stdout"].write("Rolled %s and got %s%s" % ( - roll, results_str, total)) - else: - event["stderr"].write("Invalid format. Example: 2d12+2") + results_str = ", ".join(str(r) for r in results) + event["stdout"].write("Rolled %s and got %s%s" % ( + roll, results_str, total)) diff --git a/src/utils/parse/spec.py b/src/utils/parse/spec.py index 407eec0f..e03c8e23 100644 --- a/src/utils/parse/spec.py +++ b/src/utils/parse/spec.py @@ -1,4 +1,4 @@ -import enum, typing +import enum, re, typing from .time import duration from .types import try_int from src.utils.datetime.parse import date_human @@ -17,11 +17,15 @@ class SpecArgumentType(object): context = SpecArgumentContext.ALL def __init__(self, type_name: str, name: typing.Optional[str], - exported: typing.Optional[str]): + modifier: typing.Optional[str], exported: typing.Optional[str]): self.type = type_name self._name = name + self._set_modifier(modifier) self.exported = exported + def _set_modifier(self, modifier: str): + pass + def name(self) -> typing.Optional[str]: return self._name def simple(self, args: typing.List[str]) -> typing.Tuple[typing.Any, int]: @@ -29,6 +33,18 @@ class SpecArgumentType(object): def error(self) -> typing.Optional[str]: return None +class SpecArgumentTypePattern(SpecArgumentType): + _pattern: typing.Pattern + def _set_modifier(self, modifier): + print(modifier) + self._pattern = re.compile(modifier) + def simple(self, args): + match = self._pattern.search(" ".join(args)) + if match: + return match, match.group(0).rstrip(" ").count(" ") + else: + return None, 1 + class SpecArgumentTypeWord(SpecArgumentType): def simple(self, args): if args: @@ -99,7 +115,8 @@ SPEC_ARGUMENT_TYPES = { "tstring": SpecArgumentTypeTrimString, "int": SpecArgumentTypeInt, "date": SpecArgumentTypeDate, - "duration": SpecArgumentTypeDuration + "duration": SpecArgumentTypeDuration, + "pattern": SpecArgumentTypePattern } class SpecArgument(object): @@ -118,10 +135,16 @@ class SpecArgument(object): argument_type_name: typing.Optional[str] = None name_end = argument_type.find(">") - if argument_type.startswith("<") and name_end > 0: + if name_end > 0 and argument_type.startswith("<"): argument_type_name = argument_type[1:name_end] argument_type = argument_type[name_end+1:] + argument_type_modifier: typing.Optional[str] = None + modifier_start = argument_type.find("(") + if modifier_start > 0 and argument_type.endswith(")"): + argument_type_modifier = argument_type[modifier_start+1:-1] + argument_type = argument_type[:modifier_start] + argument_type_class = SpecArgumentType if argument_type in SPEC_ARGUMENT_TYPES: argument_type_class = SPEC_ARGUMENT_TYPES[argument_type] @@ -129,7 +152,7 @@ class SpecArgument(object): argument_type_class = SpecArgumentPrivateType out.append(argument_type_class(argument_type, - argument_type_name, exported)) + argument_type_name, argument_type_modifier, exported)) spec_argument = SpecArgument() spec_argument.optional = optional @@ -165,7 +188,7 @@ class SpecLiteralArgument(SpecArgument): spec_argument = SpecLiteralArgument() spec_argument.optional = optional spec_argument.types = [ - SpecArgumentTypeLiteral("literal", l, None) for l in literals] + SpecArgumentTypeLiteral("literal", l, None, None) for l in literals] return spec_argument def format(self, context: SpecArgumentContext) -> typing.Optional[str]: