diff --git a/src/Cache.py b/src/Cache.py index 09c0230f..f28a0952 100644 --- a/src/Cache.py +++ b/src/Cache.py @@ -1,6 +1,7 @@ import hashlib, time, typing, uuid +from src import PollHook -class Cache(object): +class Cache(PollHook.PollHook): def __init__(self): self._items = {} @@ -21,7 +22,7 @@ class Cache(object): self._items[id] = [key, value, expiration] return id - def next_expiration(self) -> typing.Optional[float]: + def next(self) -> typing.Optional[float]: if not self._cached_expiration == None: return self._cached_expiration @@ -36,7 +37,7 @@ class Cache(object): self._cached_expiration = expiration return expiration - def expire(self): + def call(self): now = time.monotonic() expired = [] for id in self._items.keys(): diff --git a/src/IRCBot.py b/src/IRCBot.py index 2b81a223..b03f2037 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,7 +1,7 @@ import enum, queue, os, queue, select, socket, sys, threading, time, traceback import typing, uuid from src import EventManager, Exports, IRCServer, Logging, ModuleManager -from src import Socket, utils +from src import PollHook, Socket, utils VERSION = "v1.11.1" SOURCE = "https://git.io/bitbot" @@ -23,6 +23,17 @@ class TriggerEvent(object): class BitBotPanic(Exception): pass +class ListLambdaPollHook(PollHook.PollHook): + def __init__(self, + collection: typing.Callable[[], typing.Iterable[typing.Any]], + func: typing.Callable[[typing.Any], None]): + self._collection = collection + self._func = func + def next(self): + timeouts = [self._func(i) for i in self._collection()] + timeouts = [t for t in timeouts if t is not None] + return min(timeouts or [None]) + class Bot(object): def __init__(self, directory, args, cache, config, database, events, exports, log, modules, timers): @@ -57,8 +68,28 @@ class Bot(object): self._read_thread = None self._write_thread = None + self._poll_timeouts = [] # typing.List[PollHook] + self._poll_timeouts.append(self._timers) + self._poll_timeouts.append(self.cache) + + self._poll_timeouts.append(ListLambdaPollHook( + lambda: self.servers.values(), + lambda server: server.until_read_timeout())) + + self._poll_timeouts.append(ListLambdaPollHook( + lambda: self.servers.values(), + lambda server: server.until_next_ping())) + + self._poll_timeouts.append(ListLambdaPollHook( + lambda: self.servers.values(), self._throttle_timeout)) + self._events.on("timer.reconnect").hook(self._timed_reconnect) + def _throttle_timeout(self, server: IRCServer.Server): + if server.socket.waiting_throttled_send(): + return server.socket.send_throttle_timeout() + return None + def _trigger_both(self): self.trigger_read() self.trigger_write() @@ -173,40 +204,11 @@ class Bot(object): self._read_poll.register(server.fileno(), select.POLLIN) return True - def next_send(self) -> typing.Optional[float]: - next = None - for server in self.servers.values(): - timeout = server.socket.send_throttle_timeout() - if (server.socket.waiting_throttled_send() and - (next == None or timeout < next)): - next = timeout - return next - - def next_ping(self) -> typing.Optional[float]: - timeouts = [] - for server in self.servers.values(): - timeout = server.until_next_ping() - if not timeout == None: - timeouts.append(timeout) - if not timeouts: - return None - return min(timeouts) - - def next_read_timeout(self) -> typing.Optional[float]: - timeouts = [] - for server in self.servers.values(): - timeouts.append(server.until_read_timeout()) - if not timeouts: - return None - return min(timeouts) - def get_poll_timeout(self) -> float: timeouts = [] - timeouts.append(self._timers.next()) - timeouts.append(self.next_send()) - timeouts.append(self.next_ping()) - timeouts.append(self.next_read_timeout()) - timeouts.append(self.cache.next_expiration()) + for poll_timeout in self._poll_timeouts: + timeouts.append(poll_timeout.next()) + min_secs = min([timeout for timeout in timeouts if not timeout == None]) return max([min_secs, 0]) @@ -374,8 +376,8 @@ class Bot(object): server.disconnect() def _check(self): - self._timers.call() - self.cache.expire() + for poll_timeout in self._poll_timeouts: + poll_timeout.call() throttle_filled = False for server in list(self.servers.values()): diff --git a/src/PollHook.py b/src/PollHook.py new file mode 100644 index 00000000..a0e9d8c5 --- /dev/null +++ b/src/PollHook.py @@ -0,0 +1,7 @@ +import typing + +class PollHook(object): + def next(self) -> typing.Optional[float]: + return None + def call(self): + return None diff --git a/src/Timers.py b/src/Timers.py index 8eff1013..e74d4a53 100644 --- a/src/Timers.py +++ b/src/Timers.py @@ -1,5 +1,5 @@ import time, typing, uuid -from src import Database, EventManager, Logging +from src import Database, EventManager, Logging, PollHook class Timer(object): def __init__(self, id: str, context: typing.Optional[str], name: str, @@ -32,7 +32,7 @@ class Timer(object): def done(self) -> bool: return self._done -class Timers(object): +class Timers(PollHook.PollHook): def __init__(self, database: Database.Database, events: EventManager.Events, log: Logging.Log):