Add a way to track non-IRC sockets within the main epoll loop; use this for a
unix domain control socket!
This commit is contained in:
parent
1fa66eebc6
commit
0794a5173a
4 changed files with 139 additions and 16 deletions
36
src/ControlSocket.py
Normal file
36
src/ControlSocket.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import os, socket
|
||||
from src import Socket
|
||||
|
||||
class ControlSocket(object):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
location = bot.config["control-socket"]
|
||||
if os.path.exists(location):
|
||||
os.unlink(location)
|
||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.socket.bind(location)
|
||||
self.socket.listen()
|
||||
self.connected = True
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno()
|
||||
def waiting_send(self):
|
||||
return False
|
||||
def _send(self):
|
||||
pass
|
||||
def read(self):
|
||||
client, addr = self.socket.accept()
|
||||
self.bot.add_socket(Socket.Socket(client, self.on_read))
|
||||
return []
|
||||
def parse_data(self, data):
|
||||
command = data.split(" ", 1)[0].upper()
|
||||
if command == "TRIGGER":
|
||||
pass
|
||||
else:
|
||||
raise ValueError("unknown control socket command: '%s'" %
|
||||
command)
|
||||
|
||||
def on_read(self, sock, data):
|
||||
data = data.strip("\r\n")
|
||||
print(data)
|
|
@ -1,5 +1,6 @@
|
|||
import os, select, sys, threading, time, traceback, uuid
|
||||
from . import EventManager, Exports, IRCServer, Logging, ModuleManager
|
||||
from src import ControlSocket, EventManager, Exports, IRCServer, Logging
|
||||
from src import ModuleManager, utils
|
||||
|
||||
class Bot(object):
|
||||
def __init__(self, directory, args, cache, config, database, events,
|
||||
|
@ -15,13 +16,16 @@ class Bot(object):
|
|||
self.modules = modules
|
||||
self.timers = timers
|
||||
|
||||
events.on("timer.reconnect").hook(self.reconnect)
|
||||
self.start_time = time.time()
|
||||
self.lock = threading.Lock()
|
||||
self.servers = {}
|
||||
self.running = True
|
||||
self.poll = select.epoll()
|
||||
|
||||
self.servers = {}
|
||||
self.other_sockets = {}
|
||||
self.control_socket = ControlSocket.ControlSocket(self)
|
||||
self.add_socket(self.control_socket)
|
||||
|
||||
def add_server(self, server_id, connect=True):
|
||||
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
||||
username, realname) = self.database.servers.get(server_id)
|
||||
|
@ -36,6 +40,14 @@ class Bot(object):
|
|||
self.connect(new_server)
|
||||
return new_server
|
||||
|
||||
def add_socket(self, sock):
|
||||
self.other_sockets[sock.fileno()] = sock
|
||||
self.poll.register(sock.fileno(), select.EPOLLIN)
|
||||
|
||||
def remove_socket(self, sock):
|
||||
del self.other_sockets[sock.fileno()]
|
||||
self.poll.unregister(sock.fileno())
|
||||
|
||||
def get_server(self, id):
|
||||
for server in self.servers.values():
|
||||
if server.id == id:
|
||||
|
@ -101,6 +113,7 @@ class Bot(object):
|
|||
pass
|
||||
del self.servers[server.fileno()]
|
||||
|
||||
@utils.hook("timer.reconnect")
|
||||
def reconnect(self, event):
|
||||
server = self.add_server(event["server_id"], False)
|
||||
if self.connect(server):
|
||||
|
@ -128,19 +141,30 @@ class Bot(object):
|
|||
self.cache.expire()
|
||||
|
||||
for fd, event in events:
|
||||
sock = None
|
||||
irc = False
|
||||
if fd in self.servers:
|
||||
server = self.servers[fd]
|
||||
sock = self.servers[fd]
|
||||
irc = True
|
||||
elif fd in self.other_sockets:
|
||||
sock = self.other_sockets[fd]
|
||||
|
||||
if sock:
|
||||
if event & select.EPOLLIN:
|
||||
lines = server.read()
|
||||
for line in lines:
|
||||
self.log.debug("%s (raw) | %s", [str(server), line])
|
||||
server.parse_line(line)
|
||||
data = sock.read()
|
||||
if data == None:
|
||||
sock.disconnect()
|
||||
for piece in data:
|
||||
if irc:
|
||||
self.log.debug("%s (raw) | %s",
|
||||
[str(sock), data])
|
||||
sock.parse_data(piece)
|
||||
elif event & select.EPOLLOUT:
|
||||
server._send()
|
||||
self.register_read(server)
|
||||
sock._send()
|
||||
self.register_read(sock)
|
||||
elif event & select.EPULLHUP:
|
||||
print("hangup")
|
||||
server.disconnect()
|
||||
sock.disconnect()
|
||||
|
||||
for server in list(self.servers.values()):
|
||||
if server.read_timed_out():
|
||||
|
@ -160,4 +184,11 @@ class Bot(object):
|
|||
str(server), reconnect_delay))
|
||||
elif server.waiting_send() and server.throttle_done():
|
||||
self.register_both(server)
|
||||
|
||||
for sock in list(self.other_sockets.values()):
|
||||
if not sock.connected:
|
||||
self.remove_socket(sock)
|
||||
elif sock.waiting_send():
|
||||
self.register_both(sock)
|
||||
|
||||
self.lock.release()
|
||||
|
|
|
@ -199,7 +199,7 @@ class Server(IRCObject.Object):
|
|||
for user in channel.users:
|
||||
user.part_channel(channel)
|
||||
del self.channels[channel.name]
|
||||
def parse_line(self, line):
|
||||
def parse_data(self, line):
|
||||
if not line:
|
||||
return
|
||||
self.events.on("raw").call(server=self, line=line)
|
||||
|
@ -212,16 +212,22 @@ class Server(IRCObject.Object):
|
|||
def read(self):
|
||||
data = b""
|
||||
try:
|
||||
data = self.read_buffer + self.socket.recv(4096)
|
||||
data = self.socket.recv(4096)
|
||||
except (ConnectionResetError, socket.timeout):
|
||||
self.disconnect()
|
||||
return []
|
||||
return None
|
||||
if not data:
|
||||
self.disconnect()
|
||||
return None
|
||||
data = self.read_buffer+data
|
||||
self.read_buffer = b""
|
||||
|
||||
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
|
||||
if data_lines[-1]:
|
||||
self.read_buffer = data_lines[-1]
|
||||
data_lines.pop(-1)
|
||||
decoded_lines = []
|
||||
|
||||
for line in data_lines:
|
||||
try:
|
||||
line = line.decode(self.get_setting(
|
||||
|
@ -233,8 +239,7 @@ class Server(IRCObject.Object):
|
|||
except:
|
||||
continue
|
||||
decoded_lines.append(line)
|
||||
if not decoded_lines:
|
||||
self.disconnect()
|
||||
|
||||
self.last_read = time.monotonic()
|
||||
self.ping_sent = False
|
||||
return decoded_lines
|
||||
|
|
51
src/Socket.py
Normal file
51
src/Socket.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
|
||||
|
||||
class Socket(object):
|
||||
def __init__(self, socket, on_read, encoding="utf8"):
|
||||
self.socket = socket
|
||||
self._on_read = on_read
|
||||
self.encoding = encoding
|
||||
|
||||
self._write_buffer = b""
|
||||
self._read_buffer = b""
|
||||
self.delimiter = None
|
||||
self.length = None
|
||||
self.connected = True
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno()
|
||||
|
||||
def disconnect(self):
|
||||
self.connected = False
|
||||
|
||||
def _decode(self, s):
|
||||
return s.decode(self.encoding) if self.encoding else s
|
||||
def _encode(self, s):
|
||||
return s.encode(self.encoding) if self.encoding else s
|
||||
|
||||
def read(self):
|
||||
data = self.socket.recv(1024)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
data = self._read_buffer+data
|
||||
self._read_buffer = b""
|
||||
if not self.delimiter == None:
|
||||
data_split = data.split(delimiter)
|
||||
if data_split[-1]:
|
||||
self._read_buffer = data_split.pop(-1)
|
||||
return [self._decode(data) for data in data_split]
|
||||
return [data.decode(self.encoding)]
|
||||
|
||||
def parse_data(self, data):
|
||||
self._on_read(self, data)
|
||||
|
||||
def send(self, data):
|
||||
self._write_buffer += self._encode(data)
|
||||
|
||||
def _send(self):
|
||||
self._write_buffer = self._write_buffer[self.socket.send(
|
||||
self._write_buffer):]
|
||||
|
||||
def waiting_send(self):
|
||||
return bool(len(self._write_buffer))
|
Loading…
Reference in a new issue