diff --git a/modules/line_handler.py b/modules/line_handler.py index 22d505fa..bac0245b 100644 --- a/modules/line_handler.py +++ b/modules/line_handler.py @@ -69,7 +69,7 @@ class Module(ModuleManager.BaseModule): # first numeric line the server sends @utils.hook("raw.received.001", default_event=True) def handle_001(self, event): - event["server"].set_write_throttling(True) + event["server"].socket.set_write_throttling(True) event["server"].name = event["prefix"].hostmask event["server"].set_own_nickname(event["args"][0]) event["server"].send_whois(event["server"].nickname) diff --git a/src/IRCBot.py b/src/IRCBot.py index 6caf720b..59b3837f 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -101,8 +101,9 @@ class Bot(object): def next_send(self) -> typing.Optional[float]: next = None for server in self.servers.values(): - timeout = server.send_throttle_timeout() - if server.waiting_send() and (next == None or timeout < next): + timeout = server.socket.send_throttle_timeout() + if (server.socket.waiting_send() and + (next == None or timeout < next)): next = timeout return next @@ -238,7 +239,8 @@ class Bot(object): self.log.warn( "Disconnected from %s, reconnecting in %d seconds", [str(server), reconnect_delay]) - elif server.waiting_send() and server.throttle_done(): + elif (server.socket.waiting_send() and + server.socket.throttle_done()): self.register_both(server) for sock in list(self.other_sockets.values()): diff --git a/src/IRCServer.py b/src/IRCServer.py index 2907b8eb..bc800ed4 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -1,10 +1,7 @@ -import collections, datetime, socket, ssl, sys, time, typing -from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine -from src import IRCObject, IRCUser, utils -THROTTLE_LINES = 4 -THROTTLE_SECONDS = 1 -UNTHROTTLED_MAX_LINES = 10 +import collections, datetime, sys, time, typing +from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine +from src import IRCObject, IRCSocket, IRCUser, utils READ_TIMEOUT_SECONDS = 120 PING_INTERVAL_SECONDS = 30 @@ -37,16 +34,6 @@ class Server(IRCObject.Object): self.batches = {} # type: typing.Dict[str, utils.irc.IRCParsedLine] self.cap_started = False - self.write_buffer = b"" - self.queued_lines = [] # type: typing.List[IRCLine.Line] - self.buffered_lines = [] # type: typing.List[IRCLine.Line] - self._write_throttling = False - self.read_buffer = b"" - self.recent_sends = [] # type: typing.List[float] - self.cached_fileno = None # type: typing.Optional[int] - self.bytes_written = 0 - self.bytes_read = 0 - self.users = {} # type: typing.Dict[str, IRCUser.User] self.new_users = set([]) #type: typing.Set[IRCUser.User] self.channels = IRCChannels.Channels(self, self.bot, self.events) @@ -88,40 +75,27 @@ class Server(IRCObject.Object): return "%s:%s%s" % (self.connection_params.hostname, "+" if self.connection_params.tls else "", self.connection_params.port) + def fileno(self) -> int: - return self.cached_fileno or self.socket.fileno() + return self.socket.fileno() def hostmask(self): return "%s!%s@%s" % (self.nickname, self.username, self.hostname) - def tls_wrap(self): - client_certificate = self.bot.config.get("tls-certificate", None) - client_key = self.bot.config.get("tls-key", None) - verify = self.get_setting("ssl-verify", True) - - server_hostname = None - if not utils.is_ip(self.connection_params.hostname): - server_hostname = self.connection_params.hostname - - self.socket = utils.security.ssl_wrap(self.socket, - cert=client_certificate, key=client_key, - verify=verify, hostname=server_hostname) - def connect(self): - ipv4 = self.connection_params.ipv4 - family = socket.AF_INET if ipv4 else socket.AF_INET6 - self.socket = socket.socket(family, socket.SOCK_STREAM) - - self.socket.settimeout(5.0) - - if self.connection_params.bindhost: - self.socket.bind((self.connection_params.bindhost, 0)) - if self.connection_params.tls: - self.tls_wrap() - - self.socket.connect((self.connection_params.hostname, - self.connection_params.port)) - self.cached_fileno = self.socket.fileno() + self.socket = IRCSocket.Socket( + self.bot.log, + self.get_setting("encoding", "utf8"), + self.get_setting("fallback-encoding", "iso-8859-1"), + self.connection_params.hostname, + self.connection_params.port, + self.connection_params.ipv4, + self.connection_params.bindhost, + self.connection_params.tls, + tls_verify=self.get_setting("ssl-verify", True), + cert=self.bot.config.get("tls-certificate", None), + key=self.bot.config.get("tls-key", None)) + self.socket.connect() if self.connection_params.password: self.send_pass(self.connection_params.password) @@ -135,16 +109,9 @@ class Server(IRCObject.Object): self.send_user(username, realname) self.send_nick(nickname) self.connected = True + def disconnect(self): - self.connected = False - try: - self.socket.shutdown(socket.SHUT_RDWR) - except: - pass - try: - self.socket.close() - except: - pass + self.socket.disconnect() def set_setting(self, setting: str, value: typing.Any): self.bot.database.server_settings.set(self.id, setting, @@ -252,46 +219,6 @@ class Server(IRCObject.Object): if not len(user.channels): self.remove_user(user) self.new_users.clear() - def read(self) -> typing.Optional[typing.List[str]]: - data = b"" - try: - data = self.socket.recv(4096) - except (ConnectionResetError, socket.timeout, OSError): - self.disconnect() - return None - if not data: - self.disconnect() - return None - self.bytes_read += len(data) - data = self.read_buffer+data - self.read_buffer = b"" - - data_lines = [line.strip(b"\r") for line in data.split(b"\n")] - if data_lines[-1]: - self.read_buffer = data_lines[-1] - self.bot.log.trace("recevied and buffered non-complete line: %s", - [data_lines[-1]]) - - data_lines.pop(-1) - decoded_lines = [] - - for line in data_lines: - encoding = self.get_setting("encoding", "utf8") - try: - decoded_line = line.decode(encoding) - except: - self.bot.log.trace("can't decode line with '%s', falling back", - [encoding]) - try: - decoded_line = line.decode(self.get_setting( - "fallback-encoding", "latin-1")) - except: - continue - decoded_lines.append(decoded_line) - - self.last_read = time.monotonic() - self.ping_sent = False - return decoded_lines def until_next_ping(self) -> typing.Optional[float]: if self.ping_sent: @@ -307,6 +234,8 @@ class Server(IRCObject.Object): def read_timed_out(self) -> bool: return self.until_read_timeout == 0 + def read(self) -> typing.Optional[typing.List[str]]: + return self.socket.read() def send(self, line: str): results = self.events.on("preprocess.send").call_unsafe( server=self, line=line) @@ -314,75 +243,15 @@ class Server(IRCObject.Object): if result: line = result break - line_stripped = line.split("\n", 1)[0].strip("\r") line_obj = IRCLine.Line(self, datetime.datetime.utcnow(), line_stripped) - self.queued_lines.append(line_obj) - + self.socket.send(line_obj) return line_obj - def _send(self): - if not len(self.write_buffer): - throttle_space = self.throttle_space() - to_buffer = self.queued_lines[:throttle_space] - self.queued_lines = self.queued_lines[throttle_space:] - for line in to_buffer: - decoded_data = line.decoded_data() - self.bot.log.debug("%s (raw send) | %s", - [str(self), decoded_data]) - self.events.on("raw.send").call_unsafe( - server=self, line=decoded_data) - - self.write_buffer += line.data() - self.buffered_lines.append(line) - - bytes_written_i = self.socket.send(self.write_buffer) - bytes_written = self.write_buffer[:bytes_written_i] - lines_sent = bytes_written.count(b"\r\n") - for i in range(lines_sent): - self.buffered_lines.pop(0).sent() - - self.write_buffer = self.write_buffer[bytes_written_i:] - - self.bytes_written += bytes_written_i - - now = time.monotonic() - self.recent_sends.append(now) - self.last_send = now - def waiting_send(self) -> bool: - return bool(len(self.write_buffer)) or bool(len(self.queued_lines)) - - def throttle_done(self) -> bool: - return self.send_throttle_timeout() == 0 - - def throttle_prune(self): - now = time.monotonic() - popped = 0 - for i, recent_send in enumerate(self.recent_sends[:]): - time_since = now-recent_send - if time_since >= THROTTLE_SECONDS: - self.recent_sends.pop(i-popped) - popped += 1 - - def throttle_space(self) -> int: - if not self._write_throttling: - return UNTHROTTLED_MAX_LINES - return max(0, THROTTLE_LINES-len(self.recent_sends)) - - def send_throttle_timeout(self) -> float: - if len(self.write_buffer) or not self._write_throttling: - return 0 - - self.throttle_prune() - if self.throttle_space() > 0: - return 0 - - time_left = self.recent_sends[0]+THROTTLE_SECONDS - time_left = time_left-time.monotonic() - return time_left - - def set_write_throttling(self, is_on: bool): - self._write_throttling = is_on + lines = self.socket._send() + for line in lines: + self.bot.log.debug("%s (raw send) | %s", [str(self), line]) + self.events.on("raw.send").call_unsafe(server=self, line=line) def send_user(self, username: str, realname: str) -> IRCLine.Line: return self.send("USER %s 0 * :%s" % (username, realname)) diff --git a/src/IRCSocket.py b/src/IRCSocket.py new file mode 100644 index 00000000..e9379ddd --- /dev/null +++ b/src/IRCSocket.py @@ -0,0 +1,177 @@ +import datetime, socket, ssl, time, typing +from src import IRCLine, Logging, IRCObject, utils + +THROTTLE_LINES = 4 +THROTTLE_SECONDS = 1 +UNTHROTTLED_MAX_LINES = 10 + +class Socket(IRCObject.Object): + def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str, + hostname: str, port: int, ipv4: bool, bindhost: str, tls: bool, + tls_verify: bool=True, cert: str=None, key: str=None): + self.log = log + + self._encoding = encoding + self._fallback_encoding = fallback_encoding + self._hostname = hostname + self._port = port + self._ipv4 = ipv4 + self._bindhost = bindhost + + self._tls = tls + self._tls_verify = tls_verify + self._cert = cert + self._key = key + + self._write_buffer = b"" + self._queued_lines = [] # type: typing.List[IRCLine.Line] + self._buffered_lines = [] # type: typing.List[IRCLine.Line] + self._write_throttling = False + self._read_buffer = b"" + self._recent_sends = [] # type: typing.List[float] + self.cached_fileno = None # type: typing.Optional[int] + self.bytes_written = 0 + self.bytes_read = 0 + + def fileno(self) -> int: + return self.cached_fileno or self._socket.fileno() + + def _tls_wrap(self): + server_hostname = None + if not utils.is_ip(self._hostname): + server_hostname = self._hostname + + self._socket = utils.security.ssl_wrap(self._socket, + cert=self._cert, key=self._key, verify=self._tls_verify, + hostname=server_hostname) + + def connect(self): + family = socket.AF_INET if self._ipv4 else socket.AF_INET6 + self._socket = socket.socket(family, socket.SOCK_STREAM) + + self._socket.settimeout(5.0) + + if self._bindhost: + self._socket.bind((self._bindhost, 0)) + if self._tls: + self._tls_wrap() + + self._socket.connect((self._hostname, self._port)) + self.cached_fileno = self._socket.fileno() + + def disconnect(self): + self.connected = False + try: + self._socket.shutdown(socket.SHUT_RDWR) + except: + pass + try: + self._socket.close() + except: + pass + + def read(self) -> typing.Optional[typing.List[str]]: + data = b"" + try: + data = self._socket.recv(4096) + except (ConnectionResetError, socket.timeout, OSError): + self.disconnect() + return None + if not data: + self.disconnect() + return None + self.bytes_read += len(data) + data = self._read_buffer+data + self._read_buffer = b"" + + data_lines = [line.strip(b"\r") for line in data.split(b"\n")] + if data_lines[-1]: + self._read_buffer = data_lines[-1] + self.log.trace("recevied and buffered non-complete line: %s", + [data_lines[-1]]) + + data_lines.pop(-1) + decoded_lines = [] + + for line in data_lines: + try: + decoded_line = line.decode(self._encoding) + except: + self.log.trace("can't decode line with '%s', falling back", + [self._encoding]) + try: + decoded_line = line.decode(self._fallback_encoding) + except: + continue + decoded_lines.append(decoded_line) + + self.last_read = time.monotonic() + self.ping_sent = False + return decoded_lines + + def send(self, line: IRCLine.Line): + self._queued_lines.append(line) + + def _send(self) -> typing.List[str]: + decoded_sent = [] + if not len(self._write_buffer): + throttle_space = self.throttle_space() + to_buffer = self._queued_lines[:throttle_space] + self._queued_lines = self._queued_lines[throttle_space:] + for line in to_buffer: + decoded_data = line.decoded_data() + decoded_sent.append(decoded_data) + + self._write_buffer += line.data() + self._buffered_lines.append(line) + + bytes_written_i = self._socket.send(self._write_buffer) + bytes_written = self._write_buffer[:bytes_written_i] + lines_sent = bytes_written.count(b"\r\n") + for i in range(lines_sent): + self._buffered_lines.pop(0).sent() + + self._write_buffer = self._write_buffer[bytes_written_i:] + + self.bytes_written += bytes_written_i + + now = time.monotonic() + self._recent_sends.append(now) + self.last_send = now + + return decoded_sent + + def waiting_send(self) -> bool: + return bool(len(self._write_buffer)) or bool(len(self._queued_lines)) + + def throttle_done(self) -> bool: + return self.send_throttle_timeout() == 0 + + def throttle_prune(self): + now = time.monotonic() + popped = 0 + for i, recent_send in enumerate(self._recent_sends[:]): + time_since = now-recent_send + if time_since >= THROTTLE_SECONDS: + self._recent_sends.pop(i-popped) + popped += 1 + + def throttle_space(self) -> int: + if not self._write_throttling: + return UNTHROTTLED_MAX_LINES + return max(0, THROTTLE_LINES-len(self._recent_sends)) + + def send_throttle_timeout(self) -> float: + if len(self._write_buffer) or not self._write_throttling: + return 0 + + self.throttle_prune() + if self.throttle_space() > 0: + return 0 + + time_left = self._recent_sends[0]+THROTTLE_SECONDS + time_left = time_left-time.monotonic() + return time_left + + def set_write_throttling(self, is_on: bool): + self._write_throttling = is_on