From 6b8593a09b6340e902f2522ac0eaf4fb0357e7bd Mon Sep 17 00:00:00 2001 From: jesopo Date: Mon, 5 Nov 2018 18:23:02 +0000 Subject: [PATCH] Pass connection parameters around in their own object (IRCConnectionParameters) --- modules/sts.py | 3 ++- src/Database.py | 14 ++++++++------ src/IRCBot.py | 29 ++++++++++++++++++----------- src/IRCServer.py | 39 +++++++++++++++++---------------------- src/utils/irc.py | 16 ++++++++++++++++ 5 files changed, 61 insertions(+), 40 deletions(-) diff --git a/modules/sts.py b/modules/sts.py index fa361604..39257de1 100644 --- a/modules/sts.py +++ b/modules/sts.py @@ -38,7 +38,8 @@ class Module(ModuleManager.BaseModule): if not event["server"].tls: self._set_policy(event["server"], int(info["port"]), None, True) - event["server"].disconnect() + self.bot.reconnect(event["server"].id, + event["server"].connection_params) else: self._change_duration(event["server"], info) diff --git a/src/Database.py b/src/Database.py index 77e83fdc..8dcf42ec 100644 --- a/src/Database.py +++ b/src/Database.py @@ -12,18 +12,20 @@ class Servers(Table): username = username or nickname realname = realname or nickname self.database.execute( - """INSERT INTO servers (alias, hostname, port, password, ipv4, - tls, bindhost, nickname, username, realname) VALUES ( + """INSERT INTO servers (alias, hostname, port, password, tls, + ipv4, bindhost, nickname, username, realname) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - [alias, hostname, port, password, ipv4, tls, bindhost, nickname, + [alias, hostname, port, password, tls, ipv4, bindhost, nickname, username, realname]) def get_all(self): return self.database.execute_fetchall( "SELECT server_id, alias FROM servers") - def get(self, id: int): + def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str, + int, typing.Optional[str], bool, bool, typing.Optional[str], str, + typing.Optional[str], typing.Optional[str]]: return self.database.execute_fetchone( - """SELECT server_id, alias, hostname, port, password, ipv4, - tls, bindhost, nickname, username, realname FROM servers WHERE + """SELECT server_id, alias, hostname, port, password, tls, + ipv4, bindhost, nickname, username, realname FROM servers WHERE server_id=?""", [id]) diff --git a/src/IRCBot.py b/src/IRCBot.py index 95c8f617..69bb3abb 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -36,14 +36,15 @@ class Bot(object): self._trigger_client.send(b"TRIGGER") self.lock.release() - def add_server(self, server_id: int, connect: bool = True - ) -> IRCServer.Server: - (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, - username, realname) = self.database.servers.get(server_id) + def add_server(self, server_id: int, connect: bool = True, + connection_params: typing.Optional[ + utils.irc.IRCConnectionParameters]=None) -> IRCServer.Server: + if not connection_params: + connection_params = utils.irc.IRCConnectionParameters( + *self.database.servers.get(server_id)) - new_server = IRCServer.Server(self, self._events, server_id, alias, - hostname, port, password, ipv4, tls, bindhost, nickname, username, - realname) + new_server = IRCServer.Server(self, self._events, + connection_params.id, connection_params) self._events.on("new.server").call(server=new_server) if not connect or not new_server.get_setting("connect", True): @@ -129,10 +130,15 @@ class Bot(object): del self.servers[server.fileno()] def _timed_reconnect(self, event: EventManager.Event): - if not self.reconnect(event["server_id"]): + if not self.reconnect(event["server_id"], event["connection_params"]): event["timer"].redo() - def reconnect(self, server_id: int) -> bool: - server = self.add_server(server_id, False) + def reconnect(self, server_id: int, connection_params: typing.Optional[ + utils.irc.IRCConnectionParameters]=None) -> bool: + old_server = self.get_server(server_id) + if old_server: + self.disconnect(old_server) + + server = self.add_server(server_id, False, connection_params) if self.connect(server): self.servers[server.fileno()] = server return True @@ -204,7 +210,8 @@ class Bot(object): reconnect_delay = self.config.get("reconnect-delay", 10) self._timers.add("reconnect", reconnect_delay, - server_id=server.id) + server_id=server.id, + connection_params=server.connection_params) print("disconnected from %s, reconnecting in %d seconds" % ( str(server), reconnect_delay)) diff --git a/src/IRCServer.py b/src/IRCServer.py index af04b5f4..b2e1607f 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -10,23 +10,13 @@ class Server(IRCObject.Object): def __init__(self, bot: "IRCBot.Bot", events: EventManager.EventHook, - id: int, alias: str, hostname: str, port: int, password: str, - ipv4: bool, tls: bool, bindhost: str, - nickname: str, username: str, realname: str): + id: int, + connection_params: utils.irc.IRCConnectionParameters): self.connected = False self.bot = bot self.events = events self.id = id - self.alias = alias - self.target_hostname = hostname - self.port = port - self.tls = tls - self.password = password - self.ipv4 = ipv4 - self.bindhost = bindhost - self.original_nickname = nickname - self.original_username = username or nickname - self.original_realname = realname or nickname + self.connection_params = connection_params self.name = None # type: typing.Optional[str] self._capability_queue = set([]) # type: typing.Set[str] @@ -67,8 +57,8 @@ class Server(IRCObject.Object): def __str__(self): if self.alias: return self.alias - return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "", - self.port) + return "%s:%s%s" % (self.connection_params.hostname, + "+" if self.tls else "", self.port) def fileno(self): return self.cached_fileno or self.socket.fileno() @@ -90,22 +80,27 @@ class Server(IRCObject.Object): self.socket = context.wrap_socket(self.socket) def connect(self): - family = socket.AF_INET if self.ipv4 else socket.AF_INET6 + ipv4 = self.connection_params.ipv4 + family = socket.AF_INET if ipv4 else socket.AF_INET6 self.socket = socket.socket(family, socket.SOCK_STREAM) + self.socket.settimeout(5.0) - if self.bindhost: + + if self.connection_params.bindhost: self.socket.bind((self.bindhost, 0)) - if self.tls: + if self.connection_params.tls: self.tls_wrap() - self.socket.connect((self.target_hostname, self.port)) + self.socket.connect((self.connection_params.hostname, + self.connection_params.port)) self.send_capibility_ls() if self.password: - self.send_pass(self.password) + self.send_pass(self.connection_params.password) - self.send_user(self.original_username, self.original_realname) - self.send_nick(self.original_nickname) + self.send_user(self.connection_params.username, + self.connection_params.realname) + self.send_nick(self.connection_params.nickname) self.connected = True def disconnect(self): self.cached_fileno = self.socket.fileno() diff --git a/src/utils/irc.py b/src/utils/irc.py index 4070f97e..c78dd179 100644 --- a/src/utils/irc.py +++ b/src/utils/irc.py @@ -142,3 +142,19 @@ def strip_font(s: str) -> str: s = s.replace(FONT_COLOR, "") return s +OPT_STR = typing.Optional[str] +class IRCConnectionParameters(object): + def __init__(self, id: int, alias: OPT_STR, hostname: str, port: int, + tls: bool, ipv4: bool, password: OPT_STR, bindhost: OPT_STR, + nickname: str, username: OPT_STR, realname: OPT_STR): + self.id = id + self.alias = alias + self.hostname = hostname + self.port = port + self.tls = tls + self.ipv4 = ipv4 + self.bindhost = bindhost + self.password = password + self.nickname = nickname + self.username = username + self.realname = realname