diff --git a/src/ControlSocket.py b/src/ControlSocket.py new file mode 100644 index 00000000..ead2624b --- /dev/null +++ b/src/ControlSocket.py @@ -0,0 +1,36 @@ +import os, socket +from src import Socket + +class ControlSocket(object): + def __init__(self, bot): + self.bot = bot + + location = bot.config["control-socket"] + if os.path.exists(location): + os.unlink(location) + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket.bind(location) + self.socket.listen() + self.connected = True + + def fileno(self): + return self.socket.fileno() + def waiting_send(self): + return False + def _send(self): + pass + def read(self): + client, addr = self.socket.accept() + self.bot.add_socket(Socket.Socket(client, self.on_read)) + return [] + def parse_data(self, data): + command = data.split(" ", 1)[0].upper() + if command == "TRIGGER": + pass + else: + raise ValueError("unknown control socket command: '%s'" % + command) + + def on_read(self, sock, data): + data = data.strip("\r\n") + print(data) diff --git a/src/IRCBot.py b/src/IRCBot.py index e12e4d9e..29c45b9b 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,5 +1,6 @@ import os, select, sys, threading, time, traceback, uuid -from . import EventManager, Exports, IRCServer, Logging, ModuleManager +from src import ControlSocket, EventManager, Exports, IRCServer, Logging +from src import ModuleManager, utils class Bot(object): def __init__(self, directory, args, cache, config, database, events, @@ -15,13 +16,16 @@ class Bot(object): self.modules = modules self.timers = timers - events.on("timer.reconnect").hook(self.reconnect) self.start_time = time.time() self.lock = threading.Lock() - self.servers = {} self.running = True self.poll = select.epoll() + self.servers = {} + self.other_sockets = {} + self.control_socket = ControlSocket.ControlSocket(self) + self.add_socket(self.control_socket) + def add_server(self, server_id, connect=True): (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, username, realname) = self.database.servers.get(server_id) @@ -36,6 +40,14 @@ class Bot(object): self.connect(new_server) return new_server + def add_socket(self, sock): + self.other_sockets[sock.fileno()] = sock + self.poll.register(sock.fileno(), select.EPOLLIN) + + def remove_socket(self, sock): + del self.other_sockets[sock.fileno()] + self.poll.unregister(sock.fileno()) + def get_server(self, id): for server in self.servers.values(): if server.id == id: @@ -101,6 +113,7 @@ class Bot(object): pass del self.servers[server.fileno()] + @utils.hook("timer.reconnect") def reconnect(self, event): server = self.add_server(event["server_id"], False) if self.connect(server): @@ -128,19 +141,30 @@ class Bot(object): self.cache.expire() for fd, event in events: + sock = None + irc = False if fd in self.servers: - server = self.servers[fd] + sock = self.servers[fd] + irc = True + elif fd in self.other_sockets: + sock = self.other_sockets[fd] + + if sock: if event & select.EPOLLIN: - lines = server.read() - for line in lines: - self.log.debug("%s (raw) | %s", [str(server), line]) - server.parse_line(line) + data = sock.read() + if data == None: + sock.disconnect() + for piece in data: + if irc: + self.log.debug("%s (raw) | %s", + [str(sock), data]) + sock.parse_data(piece) elif event & select.EPOLLOUT: - server._send() - self.register_read(server) + sock._send() + self.register_read(sock) elif event & select.EPULLHUP: print("hangup") - server.disconnect() + sock.disconnect() for server in list(self.servers.values()): if server.read_timed_out(): @@ -160,4 +184,11 @@ class Bot(object): str(server), reconnect_delay)) elif server.waiting_send() and server.throttle_done(): self.register_both(server) + + for sock in list(self.other_sockets.values()): + if not sock.connected: + self.remove_socket(sock) + elif sock.waiting_send(): + self.register_both(sock) + self.lock.release() diff --git a/src/IRCServer.py b/src/IRCServer.py index 35050ff3..c389efed 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -199,7 +199,7 @@ class Server(IRCObject.Object): for user in channel.users: user.part_channel(channel) del self.channels[channel.name] - def parse_line(self, line): + def parse_data(self, line): if not line: return self.events.on("raw").call(server=self, line=line) @@ -212,16 +212,22 @@ class Server(IRCObject.Object): def read(self): data = b"" try: - data = self.read_buffer + self.socket.recv(4096) + data = self.socket.recv(4096) except (ConnectionResetError, socket.timeout): self.disconnect() - return [] + return None + if not data: + self.disconnect() + return None + data = self.read_buffer+data self.read_buffer = b"" + data_lines = [line.strip(b"\r") for line in data.split(b"\n")] if data_lines[-1]: self.read_buffer = data_lines[-1] data_lines.pop(-1) decoded_lines = [] + for line in data_lines: try: line = line.decode(self.get_setting( @@ -233,8 +239,7 @@ class Server(IRCObject.Object): except: continue decoded_lines.append(line) - if not decoded_lines: - self.disconnect() + self.last_read = time.monotonic() self.ping_sent = False return decoded_lines diff --git a/src/Socket.py b/src/Socket.py new file mode 100644 index 00000000..f405b0a9 --- /dev/null +++ b/src/Socket.py @@ -0,0 +1,51 @@ + + +class Socket(object): + def __init__(self, socket, on_read, encoding="utf8"): + self.socket = socket + self._on_read = on_read + self.encoding = encoding + + self._write_buffer = b"" + self._read_buffer = b"" + self.delimiter = None + self.length = None + self.connected = True + + def fileno(self): + return self.socket.fileno() + + def disconnect(self): + self.connected = False + + def _decode(self, s): + return s.decode(self.encoding) if self.encoding else s + def _encode(self, s): + return s.encode(self.encoding) if self.encoding else s + + def read(self): + data = self.socket.recv(1024) + if not data: + return None + + data = self._read_buffer+data + self._read_buffer = b"" + if not self.delimiter == None: + data_split = data.split(delimiter) + if data_split[-1]: + self._read_buffer = data_split.pop(-1) + return [self._decode(data) for data in data_split] + return [data.decode(self.encoding)] + + def parse_data(self, data): + self._on_read(self, data) + + def send(self, data): + self._write_buffer += self._encode(data) + + def _send(self): + self._write_buffer = self._write_buffer[self.socket.send( + self._write_buffer):] + + def waiting_send(self): + return bool(len(self._write_buffer))