first draft of infrastructure for unix domain control socket
This commit is contained in:
parent
daeb37226a
commit
cefde48e42
2 changed files with 87 additions and 1 deletions
81
src/Control.py
Normal file
81
src/Control.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import json, os, socket, typing
|
||||||
|
from src import EventManager, PollSource
|
||||||
|
|
||||||
|
class ControlClient(object):
|
||||||
|
def __init__(self, sock: socket.socket):
|
||||||
|
self._socket = sock
|
||||||
|
self._read_buffer = b""
|
||||||
|
self._write_buffer = b""
|
||||||
|
|
||||||
|
def fileno(self) -> int:
|
||||||
|
return self._socket.fileno()
|
||||||
|
|
||||||
|
def read_lines(self) -> typing.List[str]:
|
||||||
|
data = self._socket.recv(2048)
|
||||||
|
if not data:
|
||||||
|
return []
|
||||||
|
lines = (self._read_buffer+data).split(b"\n")
|
||||||
|
lines = [line.strip(b"\r") for line in lines]
|
||||||
|
self._read_buffer = lines.pop(-1)
|
||||||
|
return [line.decode("utf8") for line in lines]
|
||||||
|
|
||||||
|
def write_line(self, line: str):
|
||||||
|
self._write_buffer += ("%s\n" % line).encode("utf8")
|
||||||
|
def _send(self):
|
||||||
|
sent = self._socket.send(self._write_buffer)
|
||||||
|
self._write_buffer = self._write_buffer[sent:]
|
||||||
|
def writeable(self) -> bool:
|
||||||
|
return bool(self._write_buffer)
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
try:
|
||||||
|
self._socket.shutdown(socket.SHUT_RDWR)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self._socket.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Control(PollSource.PollSource):
|
||||||
|
def __init__(self, events: EventManager.Events, database_location):
|
||||||
|
self._socket_location = "%s.sock" % database_location
|
||||||
|
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
self._clients = {}
|
||||||
|
|
||||||
|
def bind(self):
|
||||||
|
if os.path.exists(self._socket_location):
|
||||||
|
os.remove(self._socket_location)
|
||||||
|
self._socket.bind(self._socket_location)
|
||||||
|
self._socket.listen(1)
|
||||||
|
|
||||||
|
def get_readables(self) -> typing.List[int]:
|
||||||
|
return [self._socket.fileno()]+list(self._clients.keys())
|
||||||
|
def get_writables(self) -> typing.List[int]:
|
||||||
|
return [f for f, c in self._clients.items() if c.writeable()]
|
||||||
|
|
||||||
|
def is_readable(self, fileno: int):
|
||||||
|
if fileno == self._socket.fileno():
|
||||||
|
client, address = self._socket.accept()
|
||||||
|
self._clients[client.fileno()] = ControlClient(client)
|
||||||
|
elif fileno in self._clients:
|
||||||
|
client = self._clients[fileno]
|
||||||
|
lines = client.read_lines()
|
||||||
|
if not lines:
|
||||||
|
client.disconnect()
|
||||||
|
del self._clients[fileno]
|
||||||
|
else:
|
||||||
|
for line in lines:
|
||||||
|
response = self._parse_line(client, line)
|
||||||
|
client.write_line(response)
|
||||||
|
def is_writeable(self, fileno: int):
|
||||||
|
self._clients[fileno]._send()
|
||||||
|
|
||||||
|
def _parse_line(self, client: ControlClient, line: str):
|
||||||
|
version, _, id = line.partition(" ")
|
||||||
|
id, _, data_str = id.partition(" ")
|
||||||
|
if version == "0.1":
|
||||||
|
# data = json.loads(data_str)
|
||||||
|
response = {"action": "ack"}
|
||||||
|
return "0.1 %s %s" % (id, json.dumps(response))
|
7
start.py
7
start.py
|
@ -7,7 +7,7 @@ if sys.version_info < (3, 6):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
import atexit, argparse, faulthandler, os, platform, time
|
import atexit, argparse, faulthandler, os, platform, time
|
||||||
from src import Cache, Config, Database, EventManager, Exports, IRCBot
|
from src import Cache, Config, Control, Database, EventManager, Exports, IRCBot
|
||||||
from src import LockFile, Logging, ModuleManager, Timers, utils
|
from src import LockFile, Logging, ModuleManager, Timers, utils
|
||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
|
@ -96,6 +96,9 @@ events = EventManager.EventRoot(log).wrap()
|
||||||
exports = Exports.Exports()
|
exports = Exports.Exports()
|
||||||
timers = Timers.Timers(database, events, log)
|
timers = Timers.Timers(database, events, log)
|
||||||
|
|
||||||
|
control = Control.Control(events, args.database)
|
||||||
|
control.bind()
|
||||||
|
|
||||||
module_directories = [os.path.join(directory, "modules")]
|
module_directories = [os.path.join(directory, "modules")]
|
||||||
if args.external:
|
if args.external:
|
||||||
module_directories.append(os.path.abspath(args.external))
|
module_directories.append(os.path.abspath(args.external))
|
||||||
|
@ -111,6 +114,8 @@ bot.add_poll_hook(cache)
|
||||||
bot.add_poll_hook(lock_file)
|
bot.add_poll_hook(lock_file)
|
||||||
bot.add_poll_hook(timers)
|
bot.add_poll_hook(timers)
|
||||||
|
|
||||||
|
bot.add_poll_source(control)
|
||||||
|
|
||||||
if args.module:
|
if args.module:
|
||||||
definition = modules.find_module(args.module)
|
definition = modules.find_module(args.module)
|
||||||
module = modules.load_module(bot, definition)
|
module = modules.load_module(bot, definition)
|
||||||
|
|
Loading…
Reference in a new issue