Pass connection parameters around in their own object (IRCConnectionParameters)

This commit is contained in:
jesopo 2018-11-05 18:23:02 +00:00
parent e26df7556a
commit 6b8593a09b
5 changed files with 61 additions and 40 deletions

View file

@ -38,7 +38,8 @@ class Module(ModuleManager.BaseModule):
if not event["server"].tls:
self._set_policy(event["server"], int(info["port"]),
None, True)
event["server"].disconnect()
self.bot.reconnect(event["server"].id,
event["server"].connection_params)
else:
self._change_duration(event["server"], info)

View file

@ -12,18 +12,20 @@ class Servers(Table):
username = username or nickname
realname = realname or nickname
self.database.execute(
"""INSERT INTO servers (alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname) VALUES (
"""INSERT INTO servers (alias, hostname, port, password, tls,
ipv4, bindhost, nickname, username, realname) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
[alias, hostname, port, password, ipv4, tls, bindhost, nickname,
[alias, hostname, port, password, tls, ipv4, bindhost, nickname,
username, realname])
def get_all(self):
return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers")
def get(self, id: int):
def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str,
int, typing.Optional[str], bool, bool, typing.Optional[str], str,
typing.Optional[str], typing.Optional[str]]:
return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname FROM servers WHERE
"""SELECT server_id, alias, hostname, port, password, tls,
ipv4, bindhost, nickname, username, realname FROM servers WHERE
server_id=?""",
[id])

View file

@ -36,14 +36,15 @@ class Bot(object):
self._trigger_client.send(b"TRIGGER")
self.lock.release()
def add_server(self, server_id: int, connect: bool = True
) -> IRCServer.Server:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id)
def add_server(self, server_id: int, connect: bool = True,
connection_params: typing.Optional[
utils.irc.IRCConnectionParameters]=None) -> IRCServer.Server:
if not connection_params:
connection_params = utils.irc.IRCConnectionParameters(
*self.database.servers.get(server_id))
new_server = IRCServer.Server(self, self._events, server_id, alias,
hostname, port, password, ipv4, tls, bindhost, nickname, username,
realname)
new_server = IRCServer.Server(self, self._events,
connection_params.id, connection_params)
self._events.on("new.server").call(server=new_server)
if not connect or not new_server.get_setting("connect", True):
@ -129,10 +130,15 @@ class Bot(object):
del self.servers[server.fileno()]
def _timed_reconnect(self, event: EventManager.Event):
if not self.reconnect(event["server_id"]):
if not self.reconnect(event["server_id"], event["connection_params"]):
event["timer"].redo()
def reconnect(self, server_id: int) -> bool:
server = self.add_server(server_id, False)
def reconnect(self, server_id: int, connection_params: typing.Optional[
utils.irc.IRCConnectionParameters]=None) -> bool:
old_server = self.get_server(server_id)
if old_server:
self.disconnect(old_server)
server = self.add_server(server_id, False, connection_params)
if self.connect(server):
self.servers[server.fileno()] = server
return True
@ -204,7 +210,8 @@ class Bot(object):
reconnect_delay = self.config.get("reconnect-delay", 10)
self._timers.add("reconnect", reconnect_delay,
server_id=server.id)
server_id=server.id,
connection_params=server.connection_params)
print("disconnected from %s, reconnecting in %d seconds" % (
str(server), reconnect_delay))

View file

@ -10,23 +10,13 @@ class Server(IRCObject.Object):
def __init__(self,
bot: "IRCBot.Bot",
events: EventManager.EventHook,
id: int, alias: str, hostname: str, port: int, password: str,
ipv4: bool, tls: bool, bindhost: str,
nickname: str, username: str, realname: str):
id: int,
connection_params: utils.irc.IRCConnectionParameters):
self.connected = False
self.bot = bot
self.events = events
self.id = id
self.alias = alias
self.target_hostname = hostname
self.port = port
self.tls = tls
self.password = password
self.ipv4 = ipv4
self.bindhost = bindhost
self.original_nickname = nickname
self.original_username = username or nickname
self.original_realname = realname or nickname
self.connection_params = connection_params
self.name = None # type: typing.Optional[str]
self._capability_queue = set([]) # type: typing.Set[str]
@ -67,8 +57,8 @@ class Server(IRCObject.Object):
def __str__(self):
if self.alias:
return self.alias
return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "",
self.port)
return "%s:%s%s" % (self.connection_params.hostname,
"+" if self.tls else "", self.port)
def fileno(self):
return self.cached_fileno or self.socket.fileno()
@ -90,22 +80,27 @@ class Server(IRCObject.Object):
self.socket = context.wrap_socket(self.socket)
def connect(self):
family = socket.AF_INET if self.ipv4 else socket.AF_INET6
ipv4 = self.connection_params.ipv4
family = socket.AF_INET if ipv4 else socket.AF_INET6
self.socket = socket.socket(family, socket.SOCK_STREAM)
self.socket.settimeout(5.0)
if self.bindhost:
if self.connection_params.bindhost:
self.socket.bind((self.bindhost, 0))
if self.tls:
if self.connection_params.tls:
self.tls_wrap()
self.socket.connect((self.target_hostname, self.port))
self.socket.connect((self.connection_params.hostname,
self.connection_params.port))
self.send_capibility_ls()
if self.password:
self.send_pass(self.password)
self.send_pass(self.connection_params.password)
self.send_user(self.original_username, self.original_realname)
self.send_nick(self.original_nickname)
self.send_user(self.connection_params.username,
self.connection_params.realname)
self.send_nick(self.connection_params.nickname)
self.connected = True
def disconnect(self):
self.cached_fileno = self.socket.fileno()

View file

@ -142,3 +142,19 @@ def strip_font(s: str) -> str:
s = s.replace(FONT_COLOR, "")
return s
OPT_STR = typing.Optional[str]
class IRCConnectionParameters(object):
def __init__(self, id: int, alias: OPT_STR, hostname: str, port: int,
tls: bool, ipv4: bool, password: OPT_STR, bindhost: OPT_STR,
nickname: str, username: OPT_STR, realname: OPT_STR):
self.id = id
self.alias = alias
self.hostname = hostname
self.port = port
self.tls = tls
self.ipv4 = ipv4
self.bindhost = bindhost
self.password = password
self.nickname = nickname
self.username = username
self.realname = realname