diff --git a/IRCBot.py b/IRCBot.py index c43dcb28..f803d61d 100644 --- a/IRCBot.py +++ b/IRCBot.py @@ -65,7 +65,7 @@ class Bot(object): next = None for timer in self.timers: time_left = timer.time_left() - if not next or time_left < next: + if next == None or time_left < next: next = time_left if next == None: @@ -83,18 +83,28 @@ class Bot(object): next = None for server in self.servers.values(): timeout = server.send_throttle_timeout() - if server.waiting_send() and (not next or timeout < next): + if server.waiting_send() and (next == None or timeout < next): next = timeout - if next == None: - return None - if next < 0: - return 0 return next + def next_ping(self): + timeouts = [] + for server in self.servers.values(): + timeouts.append(server.until_next_ping()) + return min(timeouts) + def next_read_timeout(self): + timeouts = [] + for server in self.servers.values(): + timeouts.append(server.until_read_timeout()) + return min(timeouts) + def get_poll_timeout(self): - next_timer = self.next_timer() or 30 - next_write = self.next_send() or 30 - return min(next_timer, next_write) + timeouts = [] + timeouts.append(self.next_timer()) + timeouts.append(self.next_send()) + timeouts.append(self.next_ping()) + timeouts.append(self.next_read_timeout()) + return min([timeout for timeout in timeouts if not timeout == None]) def register_read(self, server): self.poll.modify(server.fileno(), select.EPOLLIN) @@ -104,10 +114,6 @@ class Bot(object): self.poll.modify(server.fileno(), select.EPOLLIN|select.EPOLLOUT) - def since_last_read(self, server): - return None if not server.last_read else time.monotonic( - )-server.last_read - def disconnect(self, server): try: self.poll.unregister(server.fileno()) @@ -157,12 +163,10 @@ class Bot(object): server.disconnect() for server in list(self.servers.values()): - since_last_read = self.since_last_read(server) - if since_last_read: - if since_last_read > 120: - print("pingout from %s" % str(server)) - server.disconnect() - elif since_last_read > 30 and not server.ping_sent: + if server.read_timed_out(): + print("pingout from %s" % str(server)) + server.disconnect() + elif server.ping_due() and not server.ping_sent: server.send_ping() server.ping_sent = True if not server.connected: diff --git a/IRCServer.py b/IRCServer.py index 46dedbb7..8d4e4b64 100644 --- a/IRCServer.py +++ b/IRCServer.py @@ -1,9 +1,12 @@ import collections, socket, ssl, sys, time import IRCChannel, IRCUser -OUR_TLS_PROTOCOL = ssl.PROTOCOL_SSLv23 THROTTLE_LINES = 4 THROTTLE_SECONDS = 1 +READ_TIMEOUT_SECONDS = 120 +PING_INTERVAL_SECONDS = 30 + +OUR_TLS_PROTOCOL = ssl.PROTOCOL_SSLv23 if hasattr(ssl, "PROTOCOL_TLS"): OUR_TLS_PROTOCOL = ssl.PROTOCOL_TLS @@ -38,7 +41,7 @@ class Server(object): self.channel_modes = [] self.channel_types = [] - self.last_read = None + self.last_read = time.monotonic() self.last_send = None self.attempted_join = {} @@ -204,6 +207,19 @@ class Server(object): self.last_read = time.monotonic() self.ping_sent = False return decoded_lines + + def until_next_ping(self): + return max(0, (self.last_read+PING_INTERVAL_SECONDS + )-time.monotonic()) + def ping_due(self): + return self.until_next_ping() == 0 + + def until_read_timeout(self): + return max(0, (self.last_read+READ_TIMEOUT_SECONDS + )-time.monotonic()) + def read_timed_out(self): + return self.until_read_timeout == 0 + def send(self, data): encoded = data.split("\n")[0].strip("\r").encode("utf8") if len(encoded) > 450: