Shift socket.socket related logic to IRCSocket.py
This commit is contained in:
parent
b6b7345764
commit
9b44b6cd13
4 changed files with 210 additions and 162 deletions
|
@ -69,7 +69,7 @@ class Module(ModuleManager.BaseModule):
|
|||
# first numeric line the server sends
|
||||
@utils.hook("raw.received.001", default_event=True)
|
||||
def handle_001(self, event):
|
||||
event["server"].set_write_throttling(True)
|
||||
event["server"].socket.set_write_throttling(True)
|
||||
event["server"].name = event["prefix"].hostmask
|
||||
event["server"].set_own_nickname(event["args"][0])
|
||||
event["server"].send_whois(event["server"].nickname)
|
||||
|
|
|
@ -101,8 +101,9 @@ class Bot(object):
|
|||
def next_send(self) -> typing.Optional[float]:
|
||||
next = None
|
||||
for server in self.servers.values():
|
||||
timeout = server.send_throttle_timeout()
|
||||
if server.waiting_send() and (next == None or timeout < next):
|
||||
timeout = server.socket.send_throttle_timeout()
|
||||
if (server.socket.waiting_send() and
|
||||
(next == None or timeout < next)):
|
||||
next = timeout
|
||||
return next
|
||||
|
||||
|
@ -238,7 +239,8 @@ class Bot(object):
|
|||
self.log.warn(
|
||||
"Disconnected from %s, reconnecting in %d seconds",
|
||||
[str(server), reconnect_delay])
|
||||
elif server.waiting_send() and server.throttle_done():
|
||||
elif (server.socket.waiting_send() and
|
||||
server.socket.throttle_done()):
|
||||
self.register_both(server)
|
||||
|
||||
for sock in list(self.other_sockets.values()):
|
||||
|
|
185
src/IRCServer.py
185
src/IRCServer.py
|
@ -1,10 +1,7 @@
|
|||
import collections, datetime, socket, ssl, sys, time, typing
|
||||
from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine
|
||||
from src import IRCObject, IRCUser, utils
|
||||
|
||||
THROTTLE_LINES = 4
|
||||
THROTTLE_SECONDS = 1
|
||||
UNTHROTTLED_MAX_LINES = 10
|
||||
import collections, datetime, sys, time, typing
|
||||
from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine
|
||||
from src import IRCObject, IRCSocket, IRCUser, utils
|
||||
|
||||
READ_TIMEOUT_SECONDS = 120
|
||||
PING_INTERVAL_SECONDS = 30
|
||||
|
@ -37,16 +34,6 @@ class Server(IRCObject.Object):
|
|||
self.batches = {} # type: typing.Dict[str, utils.irc.IRCParsedLine]
|
||||
self.cap_started = False
|
||||
|
||||
self.write_buffer = b""
|
||||
self.queued_lines = [] # type: typing.List[IRCLine.Line]
|
||||
self.buffered_lines = [] # type: typing.List[IRCLine.Line]
|
||||
self._write_throttling = False
|
||||
self.read_buffer = b""
|
||||
self.recent_sends = [] # type: typing.List[float]
|
||||
self.cached_fileno = None # type: typing.Optional[int]
|
||||
self.bytes_written = 0
|
||||
self.bytes_read = 0
|
||||
|
||||
self.users = {} # type: typing.Dict[str, IRCUser.User]
|
||||
self.new_users = set([]) #type: typing.Set[IRCUser.User]
|
||||
self.channels = IRCChannels.Channels(self, self.bot, self.events)
|
||||
|
@ -88,40 +75,27 @@ class Server(IRCObject.Object):
|
|||
return "%s:%s%s" % (self.connection_params.hostname,
|
||||
"+" if self.connection_params.tls else "",
|
||||
self.connection_params.port)
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.cached_fileno or self.socket.fileno()
|
||||
return self.socket.fileno()
|
||||
|
||||
def hostmask(self):
|
||||
return "%s!%s@%s" % (self.nickname, self.username, self.hostname)
|
||||
|
||||
def tls_wrap(self):
|
||||
client_certificate = self.bot.config.get("tls-certificate", None)
|
||||
client_key = self.bot.config.get("tls-key", None)
|
||||
verify = self.get_setting("ssl-verify", True)
|
||||
|
||||
server_hostname = None
|
||||
if not utils.is_ip(self.connection_params.hostname):
|
||||
server_hostname = self.connection_params.hostname
|
||||
|
||||
self.socket = utils.security.ssl_wrap(self.socket,
|
||||
cert=client_certificate, key=client_key,
|
||||
verify=verify, hostname=server_hostname)
|
||||
|
||||
def connect(self):
|
||||
ipv4 = self.connection_params.ipv4
|
||||
family = socket.AF_INET if ipv4 else socket.AF_INET6
|
||||
self.socket = socket.socket(family, socket.SOCK_STREAM)
|
||||
|
||||
self.socket.settimeout(5.0)
|
||||
|
||||
if self.connection_params.bindhost:
|
||||
self.socket.bind((self.connection_params.bindhost, 0))
|
||||
if self.connection_params.tls:
|
||||
self.tls_wrap()
|
||||
|
||||
self.socket.connect((self.connection_params.hostname,
|
||||
self.connection_params.port))
|
||||
self.cached_fileno = self.socket.fileno()
|
||||
self.socket = IRCSocket.Socket(
|
||||
self.bot.log,
|
||||
self.get_setting("encoding", "utf8"),
|
||||
self.get_setting("fallback-encoding", "iso-8859-1"),
|
||||
self.connection_params.hostname,
|
||||
self.connection_params.port,
|
||||
self.connection_params.ipv4,
|
||||
self.connection_params.bindhost,
|
||||
self.connection_params.tls,
|
||||
tls_verify=self.get_setting("ssl-verify", True),
|
||||
cert=self.bot.config.get("tls-certificate", None),
|
||||
key=self.bot.config.get("tls-key", None))
|
||||
self.socket.connect()
|
||||
|
||||
if self.connection_params.password:
|
||||
self.send_pass(self.connection_params.password)
|
||||
|
@ -135,16 +109,9 @@ class Server(IRCObject.Object):
|
|||
self.send_user(username, realname)
|
||||
self.send_nick(nickname)
|
||||
self.connected = True
|
||||
|
||||
def disconnect(self):
|
||||
self.connected = False
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
self.socket.disconnect()
|
||||
|
||||
def set_setting(self, setting: str, value: typing.Any):
|
||||
self.bot.database.server_settings.set(self.id, setting,
|
||||
|
@ -252,46 +219,6 @@ class Server(IRCObject.Object):
|
|||
if not len(user.channels):
|
||||
self.remove_user(user)
|
||||
self.new_users.clear()
|
||||
def read(self) -> typing.Optional[typing.List[str]]:
|
||||
data = b""
|
||||
try:
|
||||
data = self.socket.recv(4096)
|
||||
except (ConnectionResetError, socket.timeout, OSError):
|
||||
self.disconnect()
|
||||
return None
|
||||
if not data:
|
||||
self.disconnect()
|
||||
return None
|
||||
self.bytes_read += len(data)
|
||||
data = self.read_buffer+data
|
||||
self.read_buffer = b""
|
||||
|
||||
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
|
||||
if data_lines[-1]:
|
||||
self.read_buffer = data_lines[-1]
|
||||
self.bot.log.trace("recevied and buffered non-complete line: %s",
|
||||
[data_lines[-1]])
|
||||
|
||||
data_lines.pop(-1)
|
||||
decoded_lines = []
|
||||
|
||||
for line in data_lines:
|
||||
encoding = self.get_setting("encoding", "utf8")
|
||||
try:
|
||||
decoded_line = line.decode(encoding)
|
||||
except:
|
||||
self.bot.log.trace("can't decode line with '%s', falling back",
|
||||
[encoding])
|
||||
try:
|
||||
decoded_line = line.decode(self.get_setting(
|
||||
"fallback-encoding", "latin-1"))
|
||||
except:
|
||||
continue
|
||||
decoded_lines.append(decoded_line)
|
||||
|
||||
self.last_read = time.monotonic()
|
||||
self.ping_sent = False
|
||||
return decoded_lines
|
||||
|
||||
def until_next_ping(self) -> typing.Optional[float]:
|
||||
if self.ping_sent:
|
||||
|
@ -307,6 +234,8 @@ class Server(IRCObject.Object):
|
|||
def read_timed_out(self) -> bool:
|
||||
return self.until_read_timeout == 0
|
||||
|
||||
def read(self) -> typing.Optional[typing.List[str]]:
|
||||
return self.socket.read()
|
||||
def send(self, line: str):
|
||||
results = self.events.on("preprocess.send").call_unsafe(
|
||||
server=self, line=line)
|
||||
|
@ -314,75 +243,15 @@ class Server(IRCObject.Object):
|
|||
if result:
|
||||
line = result
|
||||
break
|
||||
|
||||
line_stripped = line.split("\n", 1)[0].strip("\r")
|
||||
line_obj = IRCLine.Line(self, datetime.datetime.utcnow(), line_stripped)
|
||||
self.queued_lines.append(line_obj)
|
||||
|
||||
self.socket.send(line_obj)
|
||||
return line_obj
|
||||
|
||||
def _send(self):
|
||||
if not len(self.write_buffer):
|
||||
throttle_space = self.throttle_space()
|
||||
to_buffer = self.queued_lines[:throttle_space]
|
||||
self.queued_lines = self.queued_lines[throttle_space:]
|
||||
for line in to_buffer:
|
||||
decoded_data = line.decoded_data()
|
||||
self.bot.log.debug("%s (raw send) | %s",
|
||||
[str(self), decoded_data])
|
||||
self.events.on("raw.send").call_unsafe(
|
||||
server=self, line=decoded_data)
|
||||
|
||||
self.write_buffer += line.data()
|
||||
self.buffered_lines.append(line)
|
||||
|
||||
bytes_written_i = self.socket.send(self.write_buffer)
|
||||
bytes_written = self.write_buffer[:bytes_written_i]
|
||||
lines_sent = bytes_written.count(b"\r\n")
|
||||
for i in range(lines_sent):
|
||||
self.buffered_lines.pop(0).sent()
|
||||
|
||||
self.write_buffer = self.write_buffer[bytes_written_i:]
|
||||
|
||||
self.bytes_written += bytes_written_i
|
||||
|
||||
now = time.monotonic()
|
||||
self.recent_sends.append(now)
|
||||
self.last_send = now
|
||||
def waiting_send(self) -> bool:
|
||||
return bool(len(self.write_buffer)) or bool(len(self.queued_lines))
|
||||
|
||||
def throttle_done(self) -> bool:
|
||||
return self.send_throttle_timeout() == 0
|
||||
|
||||
def throttle_prune(self):
|
||||
now = time.monotonic()
|
||||
popped = 0
|
||||
for i, recent_send in enumerate(self.recent_sends[:]):
|
||||
time_since = now-recent_send
|
||||
if time_since >= THROTTLE_SECONDS:
|
||||
self.recent_sends.pop(i-popped)
|
||||
popped += 1
|
||||
|
||||
def throttle_space(self) -> int:
|
||||
if not self._write_throttling:
|
||||
return UNTHROTTLED_MAX_LINES
|
||||
return max(0, THROTTLE_LINES-len(self.recent_sends))
|
||||
|
||||
def send_throttle_timeout(self) -> float:
|
||||
if len(self.write_buffer) or not self._write_throttling:
|
||||
return 0
|
||||
|
||||
self.throttle_prune()
|
||||
if self.throttle_space() > 0:
|
||||
return 0
|
||||
|
||||
time_left = self.recent_sends[0]+THROTTLE_SECONDS
|
||||
time_left = time_left-time.monotonic()
|
||||
return time_left
|
||||
|
||||
def set_write_throttling(self, is_on: bool):
|
||||
self._write_throttling = is_on
|
||||
lines = self.socket._send()
|
||||
for line in lines:
|
||||
self.bot.log.debug("%s (raw send) | %s", [str(self), line])
|
||||
self.events.on("raw.send").call_unsafe(server=self, line=line)
|
||||
|
||||
def send_user(self, username: str, realname: str) -> IRCLine.Line:
|
||||
return self.send("USER %s 0 * :%s" % (username, realname))
|
||||
|
|
177
src/IRCSocket.py
Normal file
177
src/IRCSocket.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
import datetime, socket, ssl, time, typing
|
||||
from src import IRCLine, Logging, IRCObject, utils
|
||||
|
||||
THROTTLE_LINES = 4
|
||||
THROTTLE_SECONDS = 1
|
||||
UNTHROTTLED_MAX_LINES = 10
|
||||
|
||||
class Socket(IRCObject.Object):
|
||||
def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str,
|
||||
hostname: str, port: int, ipv4: bool, bindhost: str, tls: bool,
|
||||
tls_verify: bool=True, cert: str=None, key: str=None):
|
||||
self.log = log
|
||||
|
||||
self._encoding = encoding
|
||||
self._fallback_encoding = fallback_encoding
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
self._ipv4 = ipv4
|
||||
self._bindhost = bindhost
|
||||
|
||||
self._tls = tls
|
||||
self._tls_verify = tls_verify
|
||||
self._cert = cert
|
||||
self._key = key
|
||||
|
||||
self._write_buffer = b""
|
||||
self._queued_lines = [] # type: typing.List[IRCLine.Line]
|
||||
self._buffered_lines = [] # type: typing.List[IRCLine.Line]
|
||||
self._write_throttling = False
|
||||
self._read_buffer = b""
|
||||
self._recent_sends = [] # type: typing.List[float]
|
||||
self.cached_fileno = None # type: typing.Optional[int]
|
||||
self.bytes_written = 0
|
||||
self.bytes_read = 0
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.cached_fileno or self._socket.fileno()
|
||||
|
||||
def _tls_wrap(self):
|
||||
server_hostname = None
|
||||
if not utils.is_ip(self._hostname):
|
||||
server_hostname = self._hostname
|
||||
|
||||
self._socket = utils.security.ssl_wrap(self._socket,
|
||||
cert=self._cert, key=self._key, verify=self._tls_verify,
|
||||
hostname=server_hostname)
|
||||
|
||||
def connect(self):
|
||||
family = socket.AF_INET if self._ipv4 else socket.AF_INET6
|
||||
self._socket = socket.socket(family, socket.SOCK_STREAM)
|
||||
|
||||
self._socket.settimeout(5.0)
|
||||
|
||||
if self._bindhost:
|
||||
self._socket.bind((self._bindhost, 0))
|
||||
if self._tls:
|
||||
self._tls_wrap()
|
||||
|
||||
self._socket.connect((self._hostname, self._port))
|
||||
self.cached_fileno = self._socket.fileno()
|
||||
|
||||
def disconnect(self):
|
||||
self.connected = False
|
||||
try:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
self._socket.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
def read(self) -> typing.Optional[typing.List[str]]:
|
||||
data = b""
|
||||
try:
|
||||
data = self._socket.recv(4096)
|
||||
except (ConnectionResetError, socket.timeout, OSError):
|
||||
self.disconnect()
|
||||
return None
|
||||
if not data:
|
||||
self.disconnect()
|
||||
return None
|
||||
self.bytes_read += len(data)
|
||||
data = self._read_buffer+data
|
||||
self._read_buffer = b""
|
||||
|
||||
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
|
||||
if data_lines[-1]:
|
||||
self._read_buffer = data_lines[-1]
|
||||
self.log.trace("recevied and buffered non-complete line: %s",
|
||||
[data_lines[-1]])
|
||||
|
||||
data_lines.pop(-1)
|
||||
decoded_lines = []
|
||||
|
||||
for line in data_lines:
|
||||
try:
|
||||
decoded_line = line.decode(self._encoding)
|
||||
except:
|
||||
self.log.trace("can't decode line with '%s', falling back",
|
||||
[self._encoding])
|
||||
try:
|
||||
decoded_line = line.decode(self._fallback_encoding)
|
||||
except:
|
||||
continue
|
||||
decoded_lines.append(decoded_line)
|
||||
|
||||
self.last_read = time.monotonic()
|
||||
self.ping_sent = False
|
||||
return decoded_lines
|
||||
|
||||
def send(self, line: IRCLine.Line):
|
||||
self._queued_lines.append(line)
|
||||
|
||||
def _send(self) -> typing.List[str]:
|
||||
decoded_sent = []
|
||||
if not len(self._write_buffer):
|
||||
throttle_space = self.throttle_space()
|
||||
to_buffer = self._queued_lines[:throttle_space]
|
||||
self._queued_lines = self._queued_lines[throttle_space:]
|
||||
for line in to_buffer:
|
||||
decoded_data = line.decoded_data()
|
||||
decoded_sent.append(decoded_data)
|
||||
|
||||
self._write_buffer += line.data()
|
||||
self._buffered_lines.append(line)
|
||||
|
||||
bytes_written_i = self._socket.send(self._write_buffer)
|
||||
bytes_written = self._write_buffer[:bytes_written_i]
|
||||
lines_sent = bytes_written.count(b"\r\n")
|
||||
for i in range(lines_sent):
|
||||
self._buffered_lines.pop(0).sent()
|
||||
|
||||
self._write_buffer = self._write_buffer[bytes_written_i:]
|
||||
|
||||
self.bytes_written += bytes_written_i
|
||||
|
||||
now = time.monotonic()
|
||||
self._recent_sends.append(now)
|
||||
self.last_send = now
|
||||
|
||||
return decoded_sent
|
||||
|
||||
def waiting_send(self) -> bool:
|
||||
return bool(len(self._write_buffer)) or bool(len(self._queued_lines))
|
||||
|
||||
def throttle_done(self) -> bool:
|
||||
return self.send_throttle_timeout() == 0
|
||||
|
||||
def throttle_prune(self):
|
||||
now = time.monotonic()
|
||||
popped = 0
|
||||
for i, recent_send in enumerate(self._recent_sends[:]):
|
||||
time_since = now-recent_send
|
||||
if time_since >= THROTTLE_SECONDS:
|
||||
self._recent_sends.pop(i-popped)
|
||||
popped += 1
|
||||
|
||||
def throttle_space(self) -> int:
|
||||
if not self._write_throttling:
|
||||
return UNTHROTTLED_MAX_LINES
|
||||
return max(0, THROTTLE_LINES-len(self._recent_sends))
|
||||
|
||||
def send_throttle_timeout(self) -> float:
|
||||
if len(self._write_buffer) or not self._write_throttling:
|
||||
return 0
|
||||
|
||||
self.throttle_prune()
|
||||
if self.throttle_space() > 0:
|
||||
return 0
|
||||
|
||||
time_left = self._recent_sends[0]+THROTTLE_SECONDS
|
||||
time_left = time_left-time.monotonic()
|
||||
return time_left
|
||||
|
||||
def set_write_throttling(self, is_on: bool):
|
||||
self._write_throttling = is_on
|
Loading…
Reference in a new issue