first draft of infrastructure for unix domain control socket

This commit is contained in:
jesopo 2019-10-11 14:00:26 +01:00
parent daeb37226a
commit cefde48e42
2 changed files with 87 additions and 1 deletions

81
src/Control.py Normal file
View file

@ -0,0 +1,81 @@
import json, os, socket, typing
from src import EventManager, PollSource
class ControlClient(object):
def __init__(self, sock: socket.socket):
self._socket = sock
self._read_buffer = b""
self._write_buffer = b""
def fileno(self) -> int:
return self._socket.fileno()
def read_lines(self) -> typing.List[str]:
data = self._socket.recv(2048)
if not data:
return []
lines = (self._read_buffer+data).split(b"\n")
lines = [line.strip(b"\r") for line in lines]
self._read_buffer = lines.pop(-1)
return [line.decode("utf8") for line in lines]
def write_line(self, line: str):
self._write_buffer += ("%s\n" % line).encode("utf8")
def _send(self):
sent = self._socket.send(self._write_buffer)
self._write_buffer = self._write_buffer[sent:]
def writeable(self) -> bool:
return bool(self._write_buffer)
def disconnect(self):
try:
self._socket.shutdown(socket.SHUT_RDWR)
except:
pass
try:
self._socket.close()
except:
pass
class Control(PollSource.PollSource):
def __init__(self, events: EventManager.Events, database_location):
self._socket_location = "%s.sock" % database_location
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._clients = {}
def bind(self):
if os.path.exists(self._socket_location):
os.remove(self._socket_location)
self._socket.bind(self._socket_location)
self._socket.listen(1)
def get_readables(self) -> typing.List[int]:
return [self._socket.fileno()]+list(self._clients.keys())
def get_writables(self) -> typing.List[int]:
return [f for f, c in self._clients.items() if c.writeable()]
def is_readable(self, fileno: int):
if fileno == self._socket.fileno():
client, address = self._socket.accept()
self._clients[client.fileno()] = ControlClient(client)
elif fileno in self._clients:
client = self._clients[fileno]
lines = client.read_lines()
if not lines:
client.disconnect()
del self._clients[fileno]
else:
for line in lines:
response = self._parse_line(client, line)
client.write_line(response)
def is_writeable(self, fileno: int):
self._clients[fileno]._send()
def _parse_line(self, client: ControlClient, line: str):
version, _, id = line.partition(" ")
id, _, data_str = id.partition(" ")
if version == "0.1":
# data = json.loads(data_str)
response = {"action": "ack"}
return "0.1 %s %s" % (id, json.dumps(response))

View file

@ -7,7 +7,7 @@ if sys.version_info < (3, 6):
sys.exit(1)
import atexit, argparse, faulthandler, os, platform, time
from src import Cache, Config, Database, EventManager, Exports, IRCBot
from src import Cache, Config, Control, Database, EventManager, Exports, IRCBot
from src import LockFile, Logging, ModuleManager, Timers, utils
faulthandler.enable()
@ -96,6 +96,9 @@ events = EventManager.EventRoot(log).wrap()
exports = Exports.Exports()
timers = Timers.Timers(database, events, log)
control = Control.Control(events, args.database)
control.bind()
module_directories = [os.path.join(directory, "modules")]
if args.external:
module_directories.append(os.path.abspath(args.external))
@ -111,6 +114,8 @@ bot.add_poll_hook(cache)
bot.add_poll_hook(lock_file)
bot.add_poll_hook(timers)
bot.add_poll_source(control)
if args.module:
definition = modules.find_module(args.module)
module = modules.load_module(bot, definition)