Pass connection parameters around in their own object (IRCConnectionParameters)
This commit is contained in:
parent
e26df7556a
commit
6b8593a09b
5 changed files with 61 additions and 40 deletions
|
@ -38,7 +38,8 @@ class Module(ModuleManager.BaseModule):
|
||||||
if not event["server"].tls:
|
if not event["server"].tls:
|
||||||
self._set_policy(event["server"], int(info["port"]),
|
self._set_policy(event["server"], int(info["port"]),
|
||||||
None, True)
|
None, True)
|
||||||
event["server"].disconnect()
|
self.bot.reconnect(event["server"].id,
|
||||||
|
event["server"].connection_params)
|
||||||
else:
|
else:
|
||||||
self._change_duration(event["server"], info)
|
self._change_duration(event["server"], info)
|
||||||
|
|
||||||
|
|
|
@ -12,18 +12,20 @@ class Servers(Table):
|
||||||
username = username or nickname
|
username = username or nickname
|
||||||
realname = realname or nickname
|
realname = realname or nickname
|
||||||
self.database.execute(
|
self.database.execute(
|
||||||
"""INSERT INTO servers (alias, hostname, port, password, ipv4,
|
"""INSERT INTO servers (alias, hostname, port, password, tls,
|
||||||
tls, bindhost, nickname, username, realname) VALUES (
|
ipv4, bindhost, nickname, username, realname) VALUES (
|
||||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
[alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
[alias, hostname, port, password, tls, ipv4, bindhost, nickname,
|
||||||
username, realname])
|
username, realname])
|
||||||
def get_all(self):
|
def get_all(self):
|
||||||
return self.database.execute_fetchall(
|
return self.database.execute_fetchall(
|
||||||
"SELECT server_id, alias FROM servers")
|
"SELECT server_id, alias FROM servers")
|
||||||
def get(self, id: int):
|
def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str,
|
||||||
|
int, typing.Optional[str], bool, bool, typing.Optional[str], str,
|
||||||
|
typing.Optional[str], typing.Optional[str]]:
|
||||||
return self.database.execute_fetchone(
|
return self.database.execute_fetchone(
|
||||||
"""SELECT server_id, alias, hostname, port, password, ipv4,
|
"""SELECT server_id, alias, hostname, port, password, tls,
|
||||||
tls, bindhost, nickname, username, realname FROM servers WHERE
|
ipv4, bindhost, nickname, username, realname FROM servers WHERE
|
||||||
server_id=?""",
|
server_id=?""",
|
||||||
[id])
|
[id])
|
||||||
|
|
||||||
|
|
|
@ -36,14 +36,15 @@ class Bot(object):
|
||||||
self._trigger_client.send(b"TRIGGER")
|
self._trigger_client.send(b"TRIGGER")
|
||||||
self.lock.release()
|
self.lock.release()
|
||||||
|
|
||||||
def add_server(self, server_id: int, connect: bool = True
|
def add_server(self, server_id: int, connect: bool = True,
|
||||||
) -> IRCServer.Server:
|
connection_params: typing.Optional[
|
||||||
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
utils.irc.IRCConnectionParameters]=None) -> IRCServer.Server:
|
||||||
username, realname) = self.database.servers.get(server_id)
|
if not connection_params:
|
||||||
|
connection_params = utils.irc.IRCConnectionParameters(
|
||||||
|
*self.database.servers.get(server_id))
|
||||||
|
|
||||||
new_server = IRCServer.Server(self, self._events, server_id, alias,
|
new_server = IRCServer.Server(self, self._events,
|
||||||
hostname, port, password, ipv4, tls, bindhost, nickname, username,
|
connection_params.id, connection_params)
|
||||||
realname)
|
|
||||||
self._events.on("new.server").call(server=new_server)
|
self._events.on("new.server").call(server=new_server)
|
||||||
|
|
||||||
if not connect or not new_server.get_setting("connect", True):
|
if not connect or not new_server.get_setting("connect", True):
|
||||||
|
@ -129,10 +130,15 @@ class Bot(object):
|
||||||
del self.servers[server.fileno()]
|
del self.servers[server.fileno()]
|
||||||
|
|
||||||
def _timed_reconnect(self, event: EventManager.Event):
|
def _timed_reconnect(self, event: EventManager.Event):
|
||||||
if not self.reconnect(event["server_id"]):
|
if not self.reconnect(event["server_id"], event["connection_params"]):
|
||||||
event["timer"].redo()
|
event["timer"].redo()
|
||||||
def reconnect(self, server_id: int) -> bool:
|
def reconnect(self, server_id: int, connection_params: typing.Optional[
|
||||||
server = self.add_server(server_id, False)
|
utils.irc.IRCConnectionParameters]=None) -> bool:
|
||||||
|
old_server = self.get_server(server_id)
|
||||||
|
if old_server:
|
||||||
|
self.disconnect(old_server)
|
||||||
|
|
||||||
|
server = self.add_server(server_id, False, connection_params)
|
||||||
if self.connect(server):
|
if self.connect(server):
|
||||||
self.servers[server.fileno()] = server
|
self.servers[server.fileno()] = server
|
||||||
return True
|
return True
|
||||||
|
@ -204,7 +210,8 @@ class Bot(object):
|
||||||
|
|
||||||
reconnect_delay = self.config.get("reconnect-delay", 10)
|
reconnect_delay = self.config.get("reconnect-delay", 10)
|
||||||
self._timers.add("reconnect", reconnect_delay,
|
self._timers.add("reconnect", reconnect_delay,
|
||||||
server_id=server.id)
|
server_id=server.id,
|
||||||
|
connection_params=server.connection_params)
|
||||||
|
|
||||||
print("disconnected from %s, reconnecting in %d seconds" % (
|
print("disconnected from %s, reconnecting in %d seconds" % (
|
||||||
str(server), reconnect_delay))
|
str(server), reconnect_delay))
|
||||||
|
|
|
@ -10,23 +10,13 @@ class Server(IRCObject.Object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
bot: "IRCBot.Bot",
|
bot: "IRCBot.Bot",
|
||||||
events: EventManager.EventHook,
|
events: EventManager.EventHook,
|
||||||
id: int, alias: str, hostname: str, port: int, password: str,
|
id: int,
|
||||||
ipv4: bool, tls: bool, bindhost: str,
|
connection_params: utils.irc.IRCConnectionParameters):
|
||||||
nickname: str, username: str, realname: str):
|
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.events = events
|
self.events = events
|
||||||
self.id = id
|
self.id = id
|
||||||
self.alias = alias
|
self.connection_params = connection_params
|
||||||
self.target_hostname = hostname
|
|
||||||
self.port = port
|
|
||||||
self.tls = tls
|
|
||||||
self.password = password
|
|
||||||
self.ipv4 = ipv4
|
|
||||||
self.bindhost = bindhost
|
|
||||||
self.original_nickname = nickname
|
|
||||||
self.original_username = username or nickname
|
|
||||||
self.original_realname = realname or nickname
|
|
||||||
self.name = None # type: typing.Optional[str]
|
self.name = None # type: typing.Optional[str]
|
||||||
|
|
||||||
self._capability_queue = set([]) # type: typing.Set[str]
|
self._capability_queue = set([]) # type: typing.Set[str]
|
||||||
|
@ -67,8 +57,8 @@ class Server(IRCObject.Object):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.alias:
|
if self.alias:
|
||||||
return self.alias
|
return self.alias
|
||||||
return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "",
|
return "%s:%s%s" % (self.connection_params.hostname,
|
||||||
self.port)
|
"+" if self.tls else "", self.port)
|
||||||
def fileno(self):
|
def fileno(self):
|
||||||
return self.cached_fileno or self.socket.fileno()
|
return self.cached_fileno or self.socket.fileno()
|
||||||
|
|
||||||
|
@ -90,22 +80,27 @@ class Server(IRCObject.Object):
|
||||||
self.socket = context.wrap_socket(self.socket)
|
self.socket = context.wrap_socket(self.socket)
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
family = socket.AF_INET if self.ipv4 else socket.AF_INET6
|
ipv4 = self.connection_params.ipv4
|
||||||
|
family = socket.AF_INET if ipv4 else socket.AF_INET6
|
||||||
self.socket = socket.socket(family, socket.SOCK_STREAM)
|
self.socket = socket.socket(family, socket.SOCK_STREAM)
|
||||||
|
|
||||||
self.socket.settimeout(5.0)
|
self.socket.settimeout(5.0)
|
||||||
if self.bindhost:
|
|
||||||
|
if self.connection_params.bindhost:
|
||||||
self.socket.bind((self.bindhost, 0))
|
self.socket.bind((self.bindhost, 0))
|
||||||
if self.tls:
|
if self.connection_params.tls:
|
||||||
self.tls_wrap()
|
self.tls_wrap()
|
||||||
|
|
||||||
self.socket.connect((self.target_hostname, self.port))
|
self.socket.connect((self.connection_params.hostname,
|
||||||
|
self.connection_params.port))
|
||||||
self.send_capibility_ls()
|
self.send_capibility_ls()
|
||||||
|
|
||||||
if self.password:
|
if self.password:
|
||||||
self.send_pass(self.password)
|
self.send_pass(self.connection_params.password)
|
||||||
|
|
||||||
self.send_user(self.original_username, self.original_realname)
|
self.send_user(self.connection_params.username,
|
||||||
self.send_nick(self.original_nickname)
|
self.connection_params.realname)
|
||||||
|
self.send_nick(self.connection_params.nickname)
|
||||||
self.connected = True
|
self.connected = True
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
self.cached_fileno = self.socket.fileno()
|
self.cached_fileno = self.socket.fileno()
|
||||||
|
|
|
@ -142,3 +142,19 @@ def strip_font(s: str) -> str:
|
||||||
s = s.replace(FONT_COLOR, "")
|
s = s.replace(FONT_COLOR, "")
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
OPT_STR = typing.Optional[str]
|
||||||
|
class IRCConnectionParameters(object):
|
||||||
|
def __init__(self, id: int, alias: OPT_STR, hostname: str, port: int,
|
||||||
|
tls: bool, ipv4: bool, password: OPT_STR, bindhost: OPT_STR,
|
||||||
|
nickname: str, username: OPT_STR, realname: OPT_STR):
|
||||||
|
self.id = id
|
||||||
|
self.alias = alias
|
||||||
|
self.hostname = hostname
|
||||||
|
self.port = port
|
||||||
|
self.tls = tls
|
||||||
|
self.ipv4 = ipv4
|
||||||
|
self.bindhost = bindhost
|
||||||
|
self.password = password
|
||||||
|
self.nickname = nickname
|
||||||
|
self.username = username
|
||||||
|
self.realname = realname
|
||||||
|
|
Loading…
Add table
Reference in a new issue