diff --git a/src/Control.py b/src/Control.py new file mode 100644 index 00000000..2fe742cc --- /dev/null +++ b/src/Control.py @@ -0,0 +1,81 @@ +import json, os, socket, typing +from src import EventManager, PollSource + +class ControlClient(object): + def __init__(self, sock: socket.socket): + self._socket = sock + self._read_buffer = b"" + self._write_buffer = b"" + + def fileno(self) -> int: + return self._socket.fileno() + + def read_lines(self) -> typing.List[str]: + data = self._socket.recv(2048) + if not data: + return [] + lines = (self._read_buffer+data).split(b"\n") + lines = [line.strip(b"\r") for line in lines] + self._read_buffer = lines.pop(-1) + return [line.decode("utf8") for line in lines] + + def write_line(self, line: str): + self._write_buffer += ("%s\n" % line).encode("utf8") + def _send(self): + sent = self._socket.send(self._write_buffer) + self._write_buffer = self._write_buffer[sent:] + def writeable(self) -> bool: + return bool(self._write_buffer) + + def disconnect(self): + try: + self._socket.shutdown(socket.SHUT_RDWR) + except: + pass + try: + self._socket.close() + except: + pass + + +class Control(PollSource.PollSource): + def __init__(self, events: EventManager.Events, database_location): + self._socket_location = "%s.sock" % database_location + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._clients = {} + + def bind(self): + if os.path.exists(self._socket_location): + os.remove(self._socket_location) + self._socket.bind(self._socket_location) + self._socket.listen(1) + + def get_readables(self) -> typing.List[int]: + return [self._socket.fileno()]+list(self._clients.keys()) + def get_writables(self) -> typing.List[int]: + return [f for f, c in self._clients.items() if c.writeable()] + + def is_readable(self, fileno: int): + if fileno == self._socket.fileno(): + client, address = self._socket.accept() + self._clients[client.fileno()] = ControlClient(client) + elif fileno in self._clients: + client = self._clients[fileno] + lines = client.read_lines() + if not lines: + client.disconnect() + del self._clients[fileno] + else: + for line in lines: + response = self._parse_line(client, line) + client.write_line(response) + def is_writeable(self, fileno: int): + self._clients[fileno]._send() + + def _parse_line(self, client: ControlClient, line: str): + version, _, id = line.partition(" ") + id, _, data_str = id.partition(" ") + if version == "0.1": +# data = json.loads(data_str) + response = {"action": "ack"} + return "0.1 %s %s" % (id, json.dumps(response)) diff --git a/start.py b/start.py index 270c021d..fbd17b23 100755 --- a/start.py +++ b/start.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 6): sys.exit(1) import atexit, argparse, faulthandler, os, platform, time -from src import Cache, Config, Database, EventManager, Exports, IRCBot +from src import Cache, Config, Control, Database, EventManager, Exports, IRCBot from src import LockFile, Logging, ModuleManager, Timers, utils faulthandler.enable() @@ -96,6 +96,9 @@ events = EventManager.EventRoot(log).wrap() exports = Exports.Exports() timers = Timers.Timers(database, events, log) +control = Control.Control(events, args.database) +control.bind() + module_directories = [os.path.join(directory, "modules")] if args.external: module_directories.append(os.path.abspath(args.external)) @@ -111,6 +114,8 @@ bot.add_poll_hook(cache) bot.add_poll_hook(lock_file) bot.add_poll_hook(timers) +bot.add_poll_source(control) + if args.module: definition = modules.find_module(args.module) module = modules.load_module(bot, definition)