Split read/write/process in to 3 different threads
This commit is contained in:
parent
2ca4fd48f7
commit
a1ebe8035e
3 changed files with 139 additions and 105 deletions
210
src/IRCBot.py
210
src/IRCBot.py
|
@ -1,4 +1,5 @@
|
|||
import enum, queue, os, select, socket, threading, time, traceback, typing, uuid
|
||||
import enum, queue, os, queue, select, socket, threading, time, traceback
|
||||
import typing, uuid
|
||||
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
|
||||
from src import Socket, utils
|
||||
|
||||
|
@ -24,18 +25,31 @@ class Bot(object):
|
|||
self._timers = timers
|
||||
|
||||
self.start_time = time.time()
|
||||
self.lock = threading.Lock()
|
||||
self.running = True
|
||||
self.poll = select.epoll()
|
||||
|
||||
self.servers = {}
|
||||
self.other_sockets = {}
|
||||
self._trigger_server, self._trigger_client = socket.socketpair()
|
||||
self.add_socket(Socket.Socket(self._trigger_server, lambda _, s: None))
|
||||
|
||||
self._event_queue = queue.Queue()
|
||||
|
||||
self._read_poll = select.epoll()
|
||||
self._write_poll = select.epoll()
|
||||
|
||||
self._rtrigger_server, self._rtrigger_client = socket.socketpair()
|
||||
self._read_poll.register(self._rtrigger_server.fileno(), select.EPOLLIN)
|
||||
|
||||
self._wtrigger_server, self._wtrigger_client = socket.socketpair()
|
||||
self._write_poll.register(self._wtrigger_server.fileno(),
|
||||
select.EPOLLIN)
|
||||
|
||||
self._read_thread = None
|
||||
self._write_thread = None
|
||||
|
||||
self._trigger_functions = []
|
||||
self._events.on("timer.reconnect").hook(self._timed_reconnect)
|
||||
|
||||
def _thread_trigger(self):
|
||||
self._rtrigger_client.send(b"TRIGGER")
|
||||
self._wtrigger_client.send(b"TRIGGER")
|
||||
|
||||
def trigger(self,
|
||||
func: typing.Optional[typing.Callable[[], typing.Any]]=None
|
||||
) -> typing.Any:
|
||||
|
@ -43,18 +57,25 @@ class Bot(object):
|
|||
|
||||
if utils.is_main_thread():
|
||||
returned = func()
|
||||
self._trigger_client.send(b"TRIGGER")
|
||||
self._thread_trigger()
|
||||
return returned
|
||||
|
||||
self.lock.acquire()
|
||||
|
||||
func_queue = queue.Queue(1) # type: queue.Queue[str]
|
||||
self._trigger_functions.append([func, func_queue])
|
||||
|
||||
self.lock.release()
|
||||
self._trigger_client.send(b"TRIGGER")
|
||||
def _action():
|
||||
try:
|
||||
returned = func()
|
||||
type = TriggerResult.Return
|
||||
except Exception as e:
|
||||
returned = e
|
||||
type = TriggerResult.Exception
|
||||
func_queue.put([type, returned])
|
||||
self._event_queue.put(_action)
|
||||
|
||||
type, returned = func_queue.get(block=True)
|
||||
|
||||
self._thread_trigger()
|
||||
|
||||
if type == TriggerResult.Exception:
|
||||
raise returned
|
||||
elif type == TriggerResult.Return:
|
||||
|
@ -95,14 +116,6 @@ class Bot(object):
|
|||
|
||||
return new_server
|
||||
|
||||
def add_socket(self, sock: socket.socket):
|
||||
self.other_sockets[sock.fileno()] = sock
|
||||
self.poll.register(sock.fileno(), select.EPOLLIN)
|
||||
|
||||
def remove_socket(self, sock: socket.socket):
|
||||
del self.other_sockets[sock.fileno()]
|
||||
self.poll.unregister(sock.fileno())
|
||||
|
||||
def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]:
|
||||
for server in self.servers.values():
|
||||
if server.id == id:
|
||||
|
@ -123,14 +136,14 @@ class Bot(object):
|
|||
[str(server), str(e)])
|
||||
return False
|
||||
self.servers[server.fileno()] = server
|
||||
self.poll.register(server.fileno(), select.EPOLLOUT)
|
||||
self._read_poll.register(server.fileno(), select.EPOLLIN)
|
||||
return True
|
||||
|
||||
def next_send(self) -> typing.Optional[float]:
|
||||
next = None
|
||||
for server in self.servers.values():
|
||||
timeout = server.socket.send_throttle_timeout()
|
||||
if (server.socket.waiting_send() and
|
||||
if (server.socket.waiting_throttled_send() and
|
||||
(next == None or timeout < next)):
|
||||
next = timeout
|
||||
return next
|
||||
|
@ -162,17 +175,13 @@ class Bot(object):
|
|||
timeouts.append(self.cache.next_expiration())
|
||||
return min([timeout for timeout in timeouts if not timeout == None])
|
||||
|
||||
def register_read(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(), select.EPOLLIN)
|
||||
def register_write(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(), select.EPOLLOUT)
|
||||
def register_both(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(),
|
||||
select.EPOLLIN|select.EPOLLOUT)
|
||||
|
||||
def disconnect(self, server: IRCServer.Server):
|
||||
try:
|
||||
self.poll.unregister(server.fileno())
|
||||
self._read_poll.unregister(server.fileno())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
try:
|
||||
self._write_poll.unregister(server.fileno())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
del self.servers[server.fileno()]
|
||||
|
@ -208,13 +217,54 @@ class Bot(object):
|
|||
def del_setting(self, setting: str):
|
||||
self.database.bot_settings.delete(setting)
|
||||
|
||||
def _daemon_thread(self, target: typing.Callable[[], None]):
|
||||
thread = threading.Thread(target=target)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def run(self):
|
||||
self._read_thread = self._daemon_thread(self._read_loop)
|
||||
self._write_thread = self._daemon_thread(self._write_loop)
|
||||
self._event_loop()
|
||||
|
||||
def _event_loop(self):
|
||||
while self.running:
|
||||
item = self._event_queue.get(block=True, timeout=None)
|
||||
item()
|
||||
|
||||
def _write_loop(self):
|
||||
while self.running:
|
||||
if not self.servers:
|
||||
break
|
||||
|
||||
events = self.poll.poll(self.get_poll_timeout())
|
||||
self.lock.acquire()
|
||||
for fd, server in self.servers.items():
|
||||
if server.socket.waiting_immediate_send():
|
||||
self._write_poll.register(fd, select.EPOLLOUT)
|
||||
|
||||
events = self._write_poll.poll()
|
||||
for fd, event in events:
|
||||
if fd == self._wtrigger_server.fileno():
|
||||
# throw away data from trigger socket
|
||||
self._wtrigger_server.recv(1024)
|
||||
elif event & select.EPOLLOUT:
|
||||
self._write_poll.unregister(fd)
|
||||
server = self.servers[fd]
|
||||
|
||||
try:
|
||||
lines = server.socket._send()
|
||||
except:
|
||||
self.log.error("Failed to write to %s", [str(server)])
|
||||
raise
|
||||
self._event_queue.put(lambda: server._post_send(lines))
|
||||
|
||||
def _read_loop(self):
|
||||
while self.running:
|
||||
if not self.servers:
|
||||
self.running = False
|
||||
break
|
||||
|
||||
events = self._read_poll.poll(self.get_poll_timeout())
|
||||
self._timers.call()
|
||||
self.cache.expire()
|
||||
|
||||
|
@ -229,64 +279,48 @@ class Bot(object):
|
|||
self._trigger_functions.clear()
|
||||
|
||||
for fd, event in events:
|
||||
sock = None
|
||||
irc = False
|
||||
if fd in self.servers:
|
||||
sock = self.servers[fd]
|
||||
irc = True
|
||||
elif fd in self.other_sockets:
|
||||
sock = self.other_sockets[fd]
|
||||
|
||||
if sock:
|
||||
if fd == self._rtrigger_server.fileno():
|
||||
# throw away data from trigger socket
|
||||
self._rtrigger_server.recv(1024)
|
||||
else:
|
||||
server = self.servers[fd]
|
||||
if event & select.EPOLLIN:
|
||||
data = sock.read()
|
||||
if data == None:
|
||||
sock.disconnect()
|
||||
lines = server.read()
|
||||
if lines == None:
|
||||
server.disconnect()
|
||||
continue
|
||||
|
||||
for piece in data:
|
||||
sock.parse_data(piece)
|
||||
elif event & select.EPOLLOUT:
|
||||
try:
|
||||
sock._send()
|
||||
except:
|
||||
self.log.error("Failed to write to %s",
|
||||
[str(sock)])
|
||||
raise
|
||||
self._event_queue.put(lambda: server._post_read(lines))
|
||||
elif event & select.EPOLLHUP:
|
||||
self.log.warn("Recieved EPOLLHUP for %s", [str(server)])
|
||||
server.disconnect()
|
||||
|
||||
if sock.fileno() in self.servers:
|
||||
self.register_read(sock)
|
||||
elif event & select.EPULLHUP:
|
||||
self.log.warn("Recieved EPOLLHUP for %s", [str(sock)])
|
||||
sock.disconnect()
|
||||
self.trigger(self._check_servers)
|
||||
|
||||
for server in list(self.servers.values()):
|
||||
if server.read_timed_out():
|
||||
self.log.warn("Pinged out from %s", [str(server)])
|
||||
server.disconnect()
|
||||
elif server.ping_due() and not server.ping_sent:
|
||||
server.send_ping()
|
||||
server.ping_sent = True
|
||||
if not server.socket.connected:
|
||||
self._events.on("server.disconnect").call(server=server)
|
||||
self.disconnect(server)
|
||||
def _check_servers(self):
|
||||
throttle_filled = False
|
||||
for server in list(self.servers.values()):
|
||||
if server.read_timed_out():
|
||||
self.log.warn("Pinged out from %s", [str(server)])
|
||||
server.disconnect()
|
||||
elif server.ping_due() and not server.ping_sent:
|
||||
server.send_ping()
|
||||
server.ping_sent = True
|
||||
if not server.socket.connected:
|
||||
self._events.on("server.disconnect").call(server=server)
|
||||
self.disconnect(server)
|
||||
|
||||
if not self.get_server_by_id(server.id):
|
||||
reconnect_delay = self.config.get("reconnect-delay", 10)
|
||||
self._timers.add("reconnect", reconnect_delay,
|
||||
server_id=server.id)
|
||||
self.log.warn(
|
||||
"Disconnected from %s, reconnecting in %d seconds",
|
||||
[str(server), reconnect_delay])
|
||||
elif server.socket.waiting_immediate_send() or (
|
||||
server.socket.waiting_send() and
|
||||
server.socket.throttle_done()):
|
||||
self.register_both(server)
|
||||
if not self.get_server_by_id(server.id):
|
||||
reconnect_delay = self.config.get("reconnect-delay", 10)
|
||||
self._timers.add("reconnect", reconnect_delay,
|
||||
server_id=server.id)
|
||||
self.log.warn(
|
||||
"Disconnected from %s, reconnecting in %d seconds",
|
||||
[str(server), reconnect_delay])
|
||||
elif (server.socket.waiting_throttled_send() and
|
||||
server.socket.throttle_done()):
|
||||
server.socket._fill_throttle()
|
||||
throttle_filled = True
|
||||
|
||||
for sock in list(self.other_sockets.values()):
|
||||
if not sock.connected:
|
||||
self.remove_socket(sock)
|
||||
elif sock.waiting_send():
|
||||
self.register_both(sock)
|
||||
|
||||
self.lock.release()
|
||||
if throttle_filled:
|
||||
self._wtrigger_client.send(b"TRIGGER")
|
||||
|
|
|
@ -207,14 +207,12 @@ class Server(IRCObject.Object):
|
|||
return utils.irc.hostmask_match(self.irc_lower(hostmask),
|
||||
self.irc_lower(pattern))
|
||||
|
||||
def parse_data(self, line: str):
|
||||
if not line:
|
||||
return
|
||||
|
||||
self.bot.log.debug("%s (raw recv) | %s", [str(self), line])
|
||||
self.events.on("raw.received").call_unsafe(server=self,
|
||||
line=utils.irc.parse_line(line))
|
||||
self.check_users()
|
||||
def _post_read(self, lines: typing.List[str]):
|
||||
for line in lines:
|
||||
self.bot.log.debug("%s (raw recv) | %s", [str(self), line])
|
||||
self.events.on("raw.received").call_unsafe(server=self,
|
||||
line=utils.irc.parse_line(line))
|
||||
self.check_users()
|
||||
def check_users(self):
|
||||
for user in self.new_users:
|
||||
if not len(user.channels):
|
||||
|
@ -261,8 +259,8 @@ class Server(IRCObject.Object):
|
|||
def send_raw(self, line: str):
|
||||
return self.send(utils.irc.parse_line(line))
|
||||
|
||||
def _send(self):
|
||||
lines = self.socket._send()
|
||||
|
||||
def _post_send(self, lines: typing.List[IRCLine.SentLine]):
|
||||
for line in lines:
|
||||
self.bot.log.debug("%s (raw send) | %s", [
|
||||
str(self), line.parsed_line.format()])
|
||||
|
|
|
@ -129,12 +129,8 @@ class Socket(IRCObject.Object):
|
|||
else:
|
||||
self._queued_lines.append(line)
|
||||
|
||||
def _send(self) -> typing.List[IRCLine.SentLine]:
|
||||
if not self._write_buffer and self._throttle_when_empty:
|
||||
self._throttle_when_empty = False
|
||||
self._write_throttling = True
|
||||
self._recent_sends.clear()
|
||||
|
||||
def _fill_throttle(self):
|
||||
throttle_space = self.throttle_space()
|
||||
if throttle_space:
|
||||
to_buffer = self._queued_lines[:throttle_space]
|
||||
|
@ -142,6 +138,12 @@ class Socket(IRCObject.Object):
|
|||
for line in to_buffer:
|
||||
self._immediate_buffer(line)
|
||||
|
||||
def _send(self) -> typing.List[IRCLine.SentLine]:
|
||||
if not self._write_buffer and self._throttle_when_empty:
|
||||
self._throttle_when_empty = False
|
||||
self._write_throttling = True
|
||||
self._recent_sends.clear()
|
||||
|
||||
bytes_written_i = self._socket.send(self._write_buffer)
|
||||
bytes_written = self._write_buffer[:bytes_written_i]
|
||||
|
||||
|
@ -165,8 +167,8 @@ class Socket(IRCObject.Object):
|
|||
def clear_send_buffer(self):
|
||||
self._queued_lines.clear()
|
||||
|
||||
def waiting_send(self) -> bool:
|
||||
return bool(len(self._write_buffer)) or bool(len(self._queued_lines))
|
||||
def waiting_throttled_send(self) -> bool:
|
||||
return bool(len(self._queued_lines))
|
||||
def waiting_immediate_send(self) -> bool:
|
||||
return bool(len(self._write_buffer))
|
||||
|
||||
|
|
Loading…
Reference in a new issue