add "poll sources" - objects that can provide additional filenos for polling

This commit is contained in:
jesopo 2019-10-11 13:59:28 +01:00
parent b7b045eadb
commit 175f90f6a2
2 changed files with 48 additions and 5 deletions

View file

@ -1,7 +1,7 @@
import enum, queue, os, queue, select, socket, sys, threading, time, traceback import enum, queue, os, queue, select, socket, sys, threading, time, traceback
import typing, uuid import typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import PollHook, Socket, Timers, utils from src import PollHook, PollSource, Socket, Timers, utils
VERSION = "v1.12.0" VERSION = "v1.12.0"
SOURCE = "https://git.io/bitbot" SOURCE = "https://git.io/bitbot"
@ -80,8 +80,12 @@ class Bot(object):
self._poll_timeouts.append(ListLambdaPollHook( self._poll_timeouts.append(ListLambdaPollHook(
lambda: self.servers.values(), self._throttle_timeout)) lambda: self.servers.values(), self._throttle_timeout))
self._poll_sources = [] # typing.List[PollSource.PollSource]
def add_poll_hook(self, hook: PollHook.PollHook): def add_poll_hook(self, hook: PollHook.PollHook):
self._poll_timeouts.append(hook) self._poll_timeouts.append(hook)
def add_poll_source(self, source: PollSource.PollSource):
self._poll_sources.append(source)
def _throttle_timeout(self, server: IRCServer.Server): def _throttle_timeout(self, server: IRCServer.Server):
if server.socket.waiting_throttled_send(): if server.socket.waiting_throttled_send():
@ -311,16 +315,24 @@ class Bot(object):
def _write_loop(self): def _write_loop(self):
while self.running: while self.running:
poll_sources = {}
with self._write_condition: with self._write_condition:
writeable = False fds = []
for fd, server in self.servers.items(): for fd, server in self.servers.items():
if server.socket.waiting_immediate_send(): if server.socket.waiting_immediate_send():
self._write_poll.register(fd, select.POLLOUT) fds.append(fd)
writeable = True
if not writeable: for poll_source in self._poll_sources:
for fileno in poll_source.get_writables():
poll_sources[fileno] = poll_source
fds.append(fileno)
if not fds:
self._write_condition.wait() self._write_condition.wait()
continue continue
else:
for fd in fds:
self._write_poll.register(fd, select.POLLOUT)
events = self._write_poll.poll() events = self._write_poll.poll()
@ -339,9 +351,25 @@ class Bot(object):
event_item = TriggerEvent(TriggerEventType.Action, event_item = TriggerEvent(TriggerEventType.Action,
self._post_send_factory(server, lines)) self._post_send_factory(server, lines))
self._event_queue.put(event_item) self._event_queue.put(event_item)
elif fd in poll_sources:
poll_sources[fd].is_writeable(fd)
def _read_loop(self): def _read_loop(self):
poll_sources = {}
while self.running: while self.running:
new_poll_sources = {}
for poll_source in self._poll_sources:
for fileno in poll_source.get_readables():
new_poll_sources[fileno] = poll_source
for fileno in new_poll_sources:
if not fileno in poll_sources:
poll_sources[fileno] = new_poll_sources[fileno]
self._read_poll.register(fileno, select.POLLIN)
for fileno in list(poll_sources.keys()):
if not fileno in new_poll_sources:
del poll_sources[fileno]
self._read_poll.unregister(fileno)
events = self._read_poll.poll() events = self._read_poll.poll()
for fd, event in events: for fd, event in events:
@ -350,6 +378,9 @@ class Bot(object):
with self._rtrigger_lock: with self._rtrigger_lock:
self._rtrigger_server.recv(1024) self._rtrigger_server.recv(1024)
self._rtriggered = False self._rtriggered = False
elif fd in poll_sources:
poll_sources[fd].is_readable(fd)
self.trigger_write()
else: else:
if not fd in self.servers: if not fd in self.servers:
self._read_poll.unregister(fd) self._read_poll.unregister(fd)

12
src/PollSource.py Normal file
View file

@ -0,0 +1,12 @@
import typing
class PollSource(object):
def get_readables(self) -> typing.List[int]:
return []
def get_writables(self) -> typing.List[int]:
return []
def is_readable(self, fileno: int):
pass
def is_writable(self, fileno: int):
pass