Add a way to track non-IRC sockets within the main epoll loop; use this for a
unix domain control socket!
This commit is contained in:
parent
1fa66eebc6
commit
0794a5173a
4 changed files with 139 additions and 16 deletions
36
src/ControlSocket.py
Normal file
36
src/ControlSocket.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
import os, socket
|
||||||
|
from src import Socket
|
||||||
|
|
||||||
|
class ControlSocket(object):
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
location = bot.config["control-socket"]
|
||||||
|
if os.path.exists(location):
|
||||||
|
os.unlink(location)
|
||||||
|
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
self.socket.bind(location)
|
||||||
|
self.socket.listen()
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self.socket.fileno()
|
||||||
|
def waiting_send(self):
|
||||||
|
return False
|
||||||
|
def _send(self):
|
||||||
|
pass
|
||||||
|
def read(self):
|
||||||
|
client, addr = self.socket.accept()
|
||||||
|
self.bot.add_socket(Socket.Socket(client, self.on_read))
|
||||||
|
return []
|
||||||
|
def parse_data(self, data):
|
||||||
|
command = data.split(" ", 1)[0].upper()
|
||||||
|
if command == "TRIGGER":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown control socket command: '%s'" %
|
||||||
|
command)
|
||||||
|
|
||||||
|
def on_read(self, sock, data):
|
||||||
|
data = data.strip("\r\n")
|
||||||
|
print(data)
|
|
@ -1,5 +1,6 @@
|
||||||
import os, select, sys, threading, time, traceback, uuid
|
import os, select, sys, threading, time, traceback, uuid
|
||||||
from . import EventManager, Exports, IRCServer, Logging, ModuleManager
|
from src import ControlSocket, EventManager, Exports, IRCServer, Logging
|
||||||
|
from src import ModuleManager, utils
|
||||||
|
|
||||||
class Bot(object):
|
class Bot(object):
|
||||||
def __init__(self, directory, args, cache, config, database, events,
|
def __init__(self, directory, args, cache, config, database, events,
|
||||||
|
@ -15,13 +16,16 @@ class Bot(object):
|
||||||
self.modules = modules
|
self.modules = modules
|
||||||
self.timers = timers
|
self.timers = timers
|
||||||
|
|
||||||
events.on("timer.reconnect").hook(self.reconnect)
|
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.servers = {}
|
|
||||||
self.running = True
|
self.running = True
|
||||||
self.poll = select.epoll()
|
self.poll = select.epoll()
|
||||||
|
|
||||||
|
self.servers = {}
|
||||||
|
self.other_sockets = {}
|
||||||
|
self.control_socket = ControlSocket.ControlSocket(self)
|
||||||
|
self.add_socket(self.control_socket)
|
||||||
|
|
||||||
def add_server(self, server_id, connect=True):
|
def add_server(self, server_id, connect=True):
|
||||||
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
||||||
username, realname) = self.database.servers.get(server_id)
|
username, realname) = self.database.servers.get(server_id)
|
||||||
|
@ -36,6 +40,14 @@ class Bot(object):
|
||||||
self.connect(new_server)
|
self.connect(new_server)
|
||||||
return new_server
|
return new_server
|
||||||
|
|
||||||
|
def add_socket(self, sock):
|
||||||
|
self.other_sockets[sock.fileno()] = sock
|
||||||
|
self.poll.register(sock.fileno(), select.EPOLLIN)
|
||||||
|
|
||||||
|
def remove_socket(self, sock):
|
||||||
|
del self.other_sockets[sock.fileno()]
|
||||||
|
self.poll.unregister(sock.fileno())
|
||||||
|
|
||||||
def get_server(self, id):
|
def get_server(self, id):
|
||||||
for server in self.servers.values():
|
for server in self.servers.values():
|
||||||
if server.id == id:
|
if server.id == id:
|
||||||
|
@ -101,6 +113,7 @@ class Bot(object):
|
||||||
pass
|
pass
|
||||||
del self.servers[server.fileno()]
|
del self.servers[server.fileno()]
|
||||||
|
|
||||||
|
@utils.hook("timer.reconnect")
|
||||||
def reconnect(self, event):
|
def reconnect(self, event):
|
||||||
server = self.add_server(event["server_id"], False)
|
server = self.add_server(event["server_id"], False)
|
||||||
if self.connect(server):
|
if self.connect(server):
|
||||||
|
@ -128,19 +141,30 @@ class Bot(object):
|
||||||
self.cache.expire()
|
self.cache.expire()
|
||||||
|
|
||||||
for fd, event in events:
|
for fd, event in events:
|
||||||
|
sock = None
|
||||||
|
irc = False
|
||||||
if fd in self.servers:
|
if fd in self.servers:
|
||||||
server = self.servers[fd]
|
sock = self.servers[fd]
|
||||||
|
irc = True
|
||||||
|
elif fd in self.other_sockets:
|
||||||
|
sock = self.other_sockets[fd]
|
||||||
|
|
||||||
|
if sock:
|
||||||
if event & select.EPOLLIN:
|
if event & select.EPOLLIN:
|
||||||
lines = server.read()
|
data = sock.read()
|
||||||
for line in lines:
|
if data == None:
|
||||||
self.log.debug("%s (raw) | %s", [str(server), line])
|
sock.disconnect()
|
||||||
server.parse_line(line)
|
for piece in data:
|
||||||
|
if irc:
|
||||||
|
self.log.debug("%s (raw) | %s",
|
||||||
|
[str(sock), data])
|
||||||
|
sock.parse_data(piece)
|
||||||
elif event & select.EPOLLOUT:
|
elif event & select.EPOLLOUT:
|
||||||
server._send()
|
sock._send()
|
||||||
self.register_read(server)
|
self.register_read(sock)
|
||||||
elif event & select.EPULLHUP:
|
elif event & select.EPULLHUP:
|
||||||
print("hangup")
|
print("hangup")
|
||||||
server.disconnect()
|
sock.disconnect()
|
||||||
|
|
||||||
for server in list(self.servers.values()):
|
for server in list(self.servers.values()):
|
||||||
if server.read_timed_out():
|
if server.read_timed_out():
|
||||||
|
@ -160,4 +184,11 @@ class Bot(object):
|
||||||
str(server), reconnect_delay))
|
str(server), reconnect_delay))
|
||||||
elif server.waiting_send() and server.throttle_done():
|
elif server.waiting_send() and server.throttle_done():
|
||||||
self.register_both(server)
|
self.register_both(server)
|
||||||
|
|
||||||
|
for sock in list(self.other_sockets.values()):
|
||||||
|
if not sock.connected:
|
||||||
|
self.remove_socket(sock)
|
||||||
|
elif sock.waiting_send():
|
||||||
|
self.register_both(sock)
|
||||||
|
|
||||||
self.lock.release()
|
self.lock.release()
|
||||||
|
|
|
@ -199,7 +199,7 @@ class Server(IRCObject.Object):
|
||||||
for user in channel.users:
|
for user in channel.users:
|
||||||
user.part_channel(channel)
|
user.part_channel(channel)
|
||||||
del self.channels[channel.name]
|
del self.channels[channel.name]
|
||||||
def parse_line(self, line):
|
def parse_data(self, line):
|
||||||
if not line:
|
if not line:
|
||||||
return
|
return
|
||||||
self.events.on("raw").call(server=self, line=line)
|
self.events.on("raw").call(server=self, line=line)
|
||||||
|
@ -212,16 +212,22 @@ class Server(IRCObject.Object):
|
||||||
def read(self):
|
def read(self):
|
||||||
data = b""
|
data = b""
|
||||||
try:
|
try:
|
||||||
data = self.read_buffer + self.socket.recv(4096)
|
data = self.socket.recv(4096)
|
||||||
except (ConnectionResetError, socket.timeout):
|
except (ConnectionResetError, socket.timeout):
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
return []
|
return None
|
||||||
|
if not data:
|
||||||
|
self.disconnect()
|
||||||
|
return None
|
||||||
|
data = self.read_buffer+data
|
||||||
self.read_buffer = b""
|
self.read_buffer = b""
|
||||||
|
|
||||||
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
|
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
|
||||||
if data_lines[-1]:
|
if data_lines[-1]:
|
||||||
self.read_buffer = data_lines[-1]
|
self.read_buffer = data_lines[-1]
|
||||||
data_lines.pop(-1)
|
data_lines.pop(-1)
|
||||||
decoded_lines = []
|
decoded_lines = []
|
||||||
|
|
||||||
for line in data_lines:
|
for line in data_lines:
|
||||||
try:
|
try:
|
||||||
line = line.decode(self.get_setting(
|
line = line.decode(self.get_setting(
|
||||||
|
@ -233,8 +239,7 @@ class Server(IRCObject.Object):
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
decoded_lines.append(line)
|
decoded_lines.append(line)
|
||||||
if not decoded_lines:
|
|
||||||
self.disconnect()
|
|
||||||
self.last_read = time.monotonic()
|
self.last_read = time.monotonic()
|
||||||
self.ping_sent = False
|
self.ping_sent = False
|
||||||
return decoded_lines
|
return decoded_lines
|
||||||
|
|
51
src/Socket.py
Normal file
51
src/Socket.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
|
||||||
|
|
||||||
|
class Socket(object):
|
||||||
|
def __init__(self, socket, on_read, encoding="utf8"):
|
||||||
|
self.socket = socket
|
||||||
|
self._on_read = on_read
|
||||||
|
self.encoding = encoding
|
||||||
|
|
||||||
|
self._write_buffer = b""
|
||||||
|
self._read_buffer = b""
|
||||||
|
self.delimiter = None
|
||||||
|
self.length = None
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self.socket.fileno()
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
def _decode(self, s):
|
||||||
|
return s.decode(self.encoding) if self.encoding else s
|
||||||
|
def _encode(self, s):
|
||||||
|
return s.encode(self.encoding) if self.encoding else s
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
data = self.socket.recv(1024)
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = self._read_buffer+data
|
||||||
|
self._read_buffer = b""
|
||||||
|
if not self.delimiter == None:
|
||||||
|
data_split = data.split(delimiter)
|
||||||
|
if data_split[-1]:
|
||||||
|
self._read_buffer = data_split.pop(-1)
|
||||||
|
return [self._decode(data) for data in data_split]
|
||||||
|
return [data.decode(self.encoding)]
|
||||||
|
|
||||||
|
def parse_data(self, data):
|
||||||
|
self._on_read(self, data)
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
self._write_buffer += self._encode(data)
|
||||||
|
|
||||||
|
def _send(self):
|
||||||
|
self._write_buffer = self._write_buffer[self.socket.send(
|
||||||
|
self._write_buffer):]
|
||||||
|
|
||||||
|
def waiting_send(self):
|
||||||
|
return bool(len(self._write_buffer))
|
Loading…
Reference in a new issue