Add a way to track non-IRC sockets within the main epoll loop; use this for a

unix domain control socket!
This commit is contained in:
jesopo 2018-10-06 15:37:05 +01:00
parent 1fa66eebc6
commit 0794a5173a
4 changed files with 139 additions and 16 deletions

36
src/ControlSocket.py Normal file
View file

@ -0,0 +1,36 @@
import os, socket
from src import Socket
class ControlSocket(object):
def __init__(self, bot):
self.bot = bot
location = bot.config["control-socket"]
if os.path.exists(location):
os.unlink(location)
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.bind(location)
self.socket.listen()
self.connected = True
def fileno(self):
return self.socket.fileno()
def waiting_send(self):
return False
def _send(self):
pass
def read(self):
client, addr = self.socket.accept()
self.bot.add_socket(Socket.Socket(client, self.on_read))
return []
def parse_data(self, data):
command = data.split(" ", 1)[0].upper()
if command == "TRIGGER":
pass
else:
raise ValueError("unknown control socket command: '%s'" %
command)
def on_read(self, sock, data):
data = data.strip("\r\n")
print(data)

View file

@ -1,5 +1,6 @@
import os, select, sys, threading, time, traceback, uuid import os, select, sys, threading, time, traceback, uuid
from . import EventManager, Exports, IRCServer, Logging, ModuleManager from src import ControlSocket, EventManager, Exports, IRCServer, Logging
from src import ModuleManager, utils
class Bot(object): class Bot(object):
def __init__(self, directory, args, cache, config, database, events, def __init__(self, directory, args, cache, config, database, events,
@ -15,13 +16,16 @@ class Bot(object):
self.modules = modules self.modules = modules
self.timers = timers self.timers = timers
events.on("timer.reconnect").hook(self.reconnect)
self.start_time = time.time() self.start_time = time.time()
self.lock = threading.Lock() self.lock = threading.Lock()
self.servers = {}
self.running = True self.running = True
self.poll = select.epoll() self.poll = select.epoll()
self.servers = {}
self.other_sockets = {}
self.control_socket = ControlSocket.ControlSocket(self)
self.add_socket(self.control_socket)
def add_server(self, server_id, connect=True): def add_server(self, server_id, connect=True):
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id) username, realname) = self.database.servers.get(server_id)
@ -36,6 +40,14 @@ class Bot(object):
self.connect(new_server) self.connect(new_server)
return new_server return new_server
def add_socket(self, sock):
self.other_sockets[sock.fileno()] = sock
self.poll.register(sock.fileno(), select.EPOLLIN)
def remove_socket(self, sock):
del self.other_sockets[sock.fileno()]
self.poll.unregister(sock.fileno())
def get_server(self, id): def get_server(self, id):
for server in self.servers.values(): for server in self.servers.values():
if server.id == id: if server.id == id:
@ -101,6 +113,7 @@ class Bot(object):
pass pass
del self.servers[server.fileno()] del self.servers[server.fileno()]
@utils.hook("timer.reconnect")
def reconnect(self, event): def reconnect(self, event):
server = self.add_server(event["server_id"], False) server = self.add_server(event["server_id"], False)
if self.connect(server): if self.connect(server):
@ -128,19 +141,30 @@ class Bot(object):
self.cache.expire() self.cache.expire()
for fd, event in events: for fd, event in events:
sock = None
irc = False
if fd in self.servers: if fd in self.servers:
server = self.servers[fd] sock = self.servers[fd]
irc = True
elif fd in self.other_sockets:
sock = self.other_sockets[fd]
if sock:
if event & select.EPOLLIN: if event & select.EPOLLIN:
lines = server.read() data = sock.read()
for line in lines: if data == None:
self.log.debug("%s (raw) | %s", [str(server), line]) sock.disconnect()
server.parse_line(line) for piece in data:
if irc:
self.log.debug("%s (raw) | %s",
[str(sock), data])
sock.parse_data(piece)
elif event & select.EPOLLOUT: elif event & select.EPOLLOUT:
server._send() sock._send()
self.register_read(server) self.register_read(sock)
elif event & select.EPULLHUP: elif event & select.EPULLHUP:
print("hangup") print("hangup")
server.disconnect() sock.disconnect()
for server in list(self.servers.values()): for server in list(self.servers.values()):
if server.read_timed_out(): if server.read_timed_out():
@ -160,4 +184,11 @@ class Bot(object):
str(server), reconnect_delay)) str(server), reconnect_delay))
elif server.waiting_send() and server.throttle_done(): elif server.waiting_send() and server.throttle_done():
self.register_both(server) self.register_both(server)
for sock in list(self.other_sockets.values()):
if not sock.connected:
self.remove_socket(sock)
elif sock.waiting_send():
self.register_both(sock)
self.lock.release() self.lock.release()

View file

@ -199,7 +199,7 @@ class Server(IRCObject.Object):
for user in channel.users: for user in channel.users:
user.part_channel(channel) user.part_channel(channel)
del self.channels[channel.name] del self.channels[channel.name]
def parse_line(self, line): def parse_data(self, line):
if not line: if not line:
return return
self.events.on("raw").call(server=self, line=line) self.events.on("raw").call(server=self, line=line)
@ -212,16 +212,22 @@ class Server(IRCObject.Object):
def read(self): def read(self):
data = b"" data = b""
try: try:
data = self.read_buffer + self.socket.recv(4096) data = self.socket.recv(4096)
except (ConnectionResetError, socket.timeout): except (ConnectionResetError, socket.timeout):
self.disconnect() self.disconnect()
return [] return None
if not data:
self.disconnect()
return None
data = self.read_buffer+data
self.read_buffer = b"" self.read_buffer = b""
data_lines = [line.strip(b"\r") for line in data.split(b"\n")] data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
if data_lines[-1]: if data_lines[-1]:
self.read_buffer = data_lines[-1] self.read_buffer = data_lines[-1]
data_lines.pop(-1) data_lines.pop(-1)
decoded_lines = [] decoded_lines = []
for line in data_lines: for line in data_lines:
try: try:
line = line.decode(self.get_setting( line = line.decode(self.get_setting(
@ -233,8 +239,7 @@ class Server(IRCObject.Object):
except: except:
continue continue
decoded_lines.append(line) decoded_lines.append(line)
if not decoded_lines:
self.disconnect()
self.last_read = time.monotonic() self.last_read = time.monotonic()
self.ping_sent = False self.ping_sent = False
return decoded_lines return decoded_lines

51
src/Socket.py Normal file
View file

@ -0,0 +1,51 @@
class Socket(object):
def __init__(self, socket, on_read, encoding="utf8"):
self.socket = socket
self._on_read = on_read
self.encoding = encoding
self._write_buffer = b""
self._read_buffer = b""
self.delimiter = None
self.length = None
self.connected = True
def fileno(self):
return self.socket.fileno()
def disconnect(self):
self.connected = False
def _decode(self, s):
return s.decode(self.encoding) if self.encoding else s
def _encode(self, s):
return s.encode(self.encoding) if self.encoding else s
def read(self):
data = self.socket.recv(1024)
if not data:
return None
data = self._read_buffer+data
self._read_buffer = b""
if not self.delimiter == None:
data_split = data.split(delimiter)
if data_split[-1]:
self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split]
return [data.decode(self.encoding)]
def parse_data(self, data):
self._on_read(self, data)
def send(self, data):
self._write_buffer += self._encode(data)
def _send(self):
self._write_buffer = self._write_buffer[self.socket.send(
self._write_buffer):]
def waiting_send(self):
return bool(len(self._write_buffer))