rewrite command output truncation

This commit is contained in:
jesopo 2020-03-03 11:15:00 +00:00
parent ea87013249
commit 7bf0b6edbf
5 changed files with 58 additions and 48 deletions

View file

@ -2,7 +2,7 @@ import datetime, typing, uuid
from src import EventManager, IRCObject, utils from src import EventManager, IRCObject, utils
# this should be 510 (RFC1459, 512 with \r\n) but a server BitBot uses is broken # this should be 510 (RFC1459, 512 with \r\n) but a server BitBot uses is broken
LINE_MAX = 470 LINE_MAX = 510
class IRCArgs(object): class IRCArgs(object):
def __init__(self, args: typing.List[str]): def __init__(self, args: typing.List[str]):
@ -125,42 +125,44 @@ class ParsedLine(object):
return tags, " ".join(pieces).replace("\r", "") return tags, " ".join(pieces).replace("\r", "")
def format(self) -> str: def format(self) -> str:
tags, line = self._format() tags, line = self._format()
line, _ = self._newline_truncate(line)
if tags: if tags:
return "%s %s" % (tags, line) return "%s %s" % (tags, line)
else: else:
return line return line
def _newline_truncate(self, line: str) -> typing.Tuple[str, str]: class SendableLine(ParsedLine):
line, sep, overflow = line.partition("\n") def __init__(self, command: str, args: typing.List[str],
return (line, overflow) margin: int=0, tags: typing.Dict[str, str]=None):
def _line_max(self, hostmask: str, margin: int) -> int: ParsedLine.__init__(self, command, args, None, tags)
return LINE_MAX-len((":%s " % hostmask).encode("utf8"))-margin self._margin = margin
def truncate(self, hostmask: str, margin: int=0) -> typing.Tuple[str, str]:
valid_bytes = b""
valid_index = -1
line_max = self._line_max(hostmask, margin) def push_last(self, arg: str, extra_margin: int=0,
human_trunc: bool=False) -> typing.Optional[str]:
last_arg = self.args[-1]
tags, line = self._format()
n = len(line.encode("utf8")) # get length of current line
n += self._margin # margin used for :hostmask
n += 1 # +1 for space on new arg
if " " in arg and not " " in last_arg:
n += 1 # +1 for colon on new arg
n += extra_margin # used for things like (more ...)
tags_formatted, line_formatted = self._format() overflow: typing.Optional[str] = None
for i, char in enumerate(line_formatted):
encoded_char = char.encode("utf8")
if (len(valid_bytes)+len(encoded_char) > line_max
or encoded_char == b"\n"):
break
else:
valid_bytes += encoded_char
valid_index = i
valid_index += 1
valid = line_formatted[:valid_index] if (n+len(arg.encode("utf8"))) > LINE_MAX:
if tags_formatted: for i, char in enumerate(arg):
valid = "%s %s" % (tags_formatted, valid) n += len(char.encode("utf8"))
overflow = line_formatted[valid_index:] if n > LINE_MAX:
if overflow and overflow[0] == "\n": arg, overflow = arg[:i], arg[i:]
overflow = overflow[1:] if human_trunc and not overflow[0] == " ":
new_arg, sep, new_overflow = arg.rpartition(" ")
return valid, overflow if sep:
arg = new_arg
overflow = new_overflow+overflow
break
if arg:
self.args[-1] = last_arg+arg
return overflow
def parse_line(line: str) -> ParsedLine: def parse_line(line: str) -> ParsedLine:
tags = {} # type: typing.Dict[str, typing.Any] tags = {} # type: typing.Dict[str, typing.Any]
@ -220,7 +222,7 @@ class SentLine(IRCObject.Object):
return self._for_wire() return self._for_wire()
def _for_wire(self) -> str: def _for_wire(self) -> str:
return self.parsed_line.truncate(self._hostmask)[0] return str(self.parsed_line)
def for_wire(self) -> bytes: def for_wire(self) -> bytes:
return b"%s\r\n" % self._for_wire().encode("utf8") return b"%s\r\n" % self._for_wire().encode("utf8")

View file

@ -80,6 +80,11 @@ class Server(IRCObject.Object):
def hostmask(self): def hostmask(self):
return "%s!%s@%s" % (self.nickname, self.username, self.hostname) return "%s!%s@%s" % (self.nickname, self.username, self.hostname)
def new_line(self, command: str, args: typing.List[str]=None,
tags: typing.Dict[str, str]=None) -> IRCLine.SendableLine:
return IRCLine.SendableLine(command, args or [],
len((":%s " % self.hostmask()).encode("utf8")), tags)
def connect(self): def connect(self):
self.socket = IRCSocket.Socket( self.socket = IRCSocket.Socket(
self.bot.log, self.bot.log,

View file

@ -8,8 +8,8 @@ COMMAND_METHOD = "command-method"
COMMAND_METHODS = ["PRIVMSG", "NOTICE"] COMMAND_METHODS = ["PRIVMSG", "NOTICE"]
STR_MORE = " (more...)" STR_MORE = " (more...)"
STR_CONTINUED = "(...continued) "
STR_MORE_LEN = len(STR_MORE.encode("utf8")) STR_MORE_LEN = len(STR_MORE.encode("utf8"))
STR_CONTINUED = "(...continued)"
WORD_BOUNDARIES = [" "] WORD_BOUNDARIES = [" "]
NON_ALPHANUMERIC = [char for char in string.printable if not char.isalnum()] NON_ALPHANUMERIC = [char for char in string.printable if not char.isalnum()]
@ -237,29 +237,25 @@ class Module(ModuleManager.BaseModule):
color = utils.consts.RED color = utils.consts.RED
line_str = obj.pop() line_str = obj.pop()
prefix = ""
if obj.prefix: if obj.prefix:
line_str = "[%s] %s" % ( prefix = "[%s] " % utils.irc.color(obj.prefix, color)
utils.irc.color(obj.prefix, color), line_str) if obj._overflowed:
prefix = "%s%s" % (prefix, STR_CONTINUED)
method = self._command_method(server, target, is_channel) method = self._command_method(server, target, is_channel)
if not method in ["PRIVMSG", "NOTICE"]: if not method in ["PRIVMSG", "NOTICE"]:
raise ValueError("Unknown command-method '%s'" % method) raise ValueError("Unknown command-method '%s'" % method)
line = IRCLine.ParsedLine(method, [target_str, line_str], line = server.new_line(method, [target_str, prefix], tags=tags)
tags=tags)
valid, trunc = line.truncate(server.hostmask(), overflow = line.push_last(line_str, human_trunc=True,
margin=STR_MORE_LEN) extra_margin=STR_MORE_LEN)
if overflow:
line.push_last(STR_MORE)
obj.insert(overflow)
obj._overflowed = True
if trunc:
if not trunc[0] in WORD_BOUNDARIES:
for boundary in WORD_BOUNDARIES:
left, *right = valid.rsplit(boundary, 1)
if right:
valid = left
trunc = right[0]+trunc
obj.insert("%s %s" % (STR_CONTINUED, trunc))
valid = valid+STR_MORE
line = IRCLine.parse_line(valid)
if obj._assured: if obj._assured:
line.assure() line.assure()
server.send(line) server.send(line)

View file

@ -6,6 +6,13 @@ class StdOut(object):
self.prefix = prefix self.prefix = prefix
self._lines = [] self._lines = []
self._assured = False self._assured = False
self._overflowed = False
def copy_from(self, other):
self.prefix = other.prefix
self._lines = other._lines
self._assured = other._assured
self._overflowed = other._overflowed
def assure(self): def assure(self):
self._assured = True self._assured = True

View file

@ -20,4 +20,4 @@ class Module(ModuleManager.BaseModule):
def more(self, event): def more(self, event):
last_stdout = event["target"]._last_stdout last_stdout = event["target"]._last_stdout
if last_stdout and last_stdout.has_text(): if last_stdout and last_stdout.has_text():
event["stdout"].write_lines(last_stdout.get_all()) event["stdout"].copy_from(last_stdout)