Pass connection parameters around in their own object (IRCConnectionParameters)
This commit is contained in:
parent
e26df7556a
commit
6b8593a09b
5 changed files with 61 additions and 40 deletions
|
@ -38,7 +38,8 @@ class Module(ModuleManager.BaseModule):
|
|||
if not event["server"].tls:
|
||||
self._set_policy(event["server"], int(info["port"]),
|
||||
None, True)
|
||||
event["server"].disconnect()
|
||||
self.bot.reconnect(event["server"].id,
|
||||
event["server"].connection_params)
|
||||
else:
|
||||
self._change_duration(event["server"], info)
|
||||
|
||||
|
|
|
@ -12,18 +12,20 @@ class Servers(Table):
|
|||
username = username or nickname
|
||||
realname = realname or nickname
|
||||
self.database.execute(
|
||||
"""INSERT INTO servers (alias, hostname, port, password, ipv4,
|
||||
tls, bindhost, nickname, username, realname) VALUES (
|
||||
"""INSERT INTO servers (alias, hostname, port, password, tls,
|
||||
ipv4, bindhost, nickname, username, realname) VALUES (
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
[alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
||||
[alias, hostname, port, password, tls, ipv4, bindhost, nickname,
|
||||
username, realname])
|
||||
def get_all(self):
|
||||
return self.database.execute_fetchall(
|
||||
"SELECT server_id, alias FROM servers")
|
||||
def get(self, id: int):
|
||||
def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str,
|
||||
int, typing.Optional[str], bool, bool, typing.Optional[str], str,
|
||||
typing.Optional[str], typing.Optional[str]]:
|
||||
return self.database.execute_fetchone(
|
||||
"""SELECT server_id, alias, hostname, port, password, ipv4,
|
||||
tls, bindhost, nickname, username, realname FROM servers WHERE
|
||||
"""SELECT server_id, alias, hostname, port, password, tls,
|
||||
ipv4, bindhost, nickname, username, realname FROM servers WHERE
|
||||
server_id=?""",
|
||||
[id])
|
||||
|
||||
|
|
|
@ -36,14 +36,15 @@ class Bot(object):
|
|||
self._trigger_client.send(b"TRIGGER")
|
||||
self.lock.release()
|
||||
|
||||
def add_server(self, server_id: int, connect: bool = True
|
||||
) -> IRCServer.Server:
|
||||
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
||||
username, realname) = self.database.servers.get(server_id)
|
||||
def add_server(self, server_id: int, connect: bool = True,
|
||||
connection_params: typing.Optional[
|
||||
utils.irc.IRCConnectionParameters]=None) -> IRCServer.Server:
|
||||
if not connection_params:
|
||||
connection_params = utils.irc.IRCConnectionParameters(
|
||||
*self.database.servers.get(server_id))
|
||||
|
||||
new_server = IRCServer.Server(self, self._events, server_id, alias,
|
||||
hostname, port, password, ipv4, tls, bindhost, nickname, username,
|
||||
realname)
|
||||
new_server = IRCServer.Server(self, self._events,
|
||||
connection_params.id, connection_params)
|
||||
self._events.on("new.server").call(server=new_server)
|
||||
|
||||
if not connect or not new_server.get_setting("connect", True):
|
||||
|
@ -129,10 +130,15 @@ class Bot(object):
|
|||
del self.servers[server.fileno()]
|
||||
|
||||
def _timed_reconnect(self, event: EventManager.Event):
|
||||
if not self.reconnect(event["server_id"]):
|
||||
if not self.reconnect(event["server_id"], event["connection_params"]):
|
||||
event["timer"].redo()
|
||||
def reconnect(self, server_id: int) -> bool:
|
||||
server = self.add_server(server_id, False)
|
||||
def reconnect(self, server_id: int, connection_params: typing.Optional[
|
||||
utils.irc.IRCConnectionParameters]=None) -> bool:
|
||||
old_server = self.get_server(server_id)
|
||||
if old_server:
|
||||
self.disconnect(old_server)
|
||||
|
||||
server = self.add_server(server_id, False, connection_params)
|
||||
if self.connect(server):
|
||||
self.servers[server.fileno()] = server
|
||||
return True
|
||||
|
@ -204,7 +210,8 @@ class Bot(object):
|
|||
|
||||
reconnect_delay = self.config.get("reconnect-delay", 10)
|
||||
self._timers.add("reconnect", reconnect_delay,
|
||||
server_id=server.id)
|
||||
server_id=server.id,
|
||||
connection_params=server.connection_params)
|
||||
|
||||
print("disconnected from %s, reconnecting in %d seconds" % (
|
||||
str(server), reconnect_delay))
|
||||
|
|
|
@ -10,23 +10,13 @@ class Server(IRCObject.Object):
|
|||
def __init__(self,
|
||||
bot: "IRCBot.Bot",
|
||||
events: EventManager.EventHook,
|
||||
id: int, alias: str, hostname: str, port: int, password: str,
|
||||
ipv4: bool, tls: bool, bindhost: str,
|
||||
nickname: str, username: str, realname: str):
|
||||
id: int,
|
||||
connection_params: utils.irc.IRCConnectionParameters):
|
||||
self.connected = False
|
||||
self.bot = bot
|
||||
self.events = events
|
||||
self.id = id
|
||||
self.alias = alias
|
||||
self.target_hostname = hostname
|
||||
self.port = port
|
||||
self.tls = tls
|
||||
self.password = password
|
||||
self.ipv4 = ipv4
|
||||
self.bindhost = bindhost
|
||||
self.original_nickname = nickname
|
||||
self.original_username = username or nickname
|
||||
self.original_realname = realname or nickname
|
||||
self.connection_params = connection_params
|
||||
self.name = None # type: typing.Optional[str]
|
||||
|
||||
self._capability_queue = set([]) # type: typing.Set[str]
|
||||
|
@ -67,8 +57,8 @@ class Server(IRCObject.Object):
|
|||
def __str__(self):
|
||||
if self.alias:
|
||||
return self.alias
|
||||
return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "",
|
||||
self.port)
|
||||
return "%s:%s%s" % (self.connection_params.hostname,
|
||||
"+" if self.tls else "", self.port)
|
||||
def fileno(self):
|
||||
return self.cached_fileno or self.socket.fileno()
|
||||
|
||||
|
@ -90,22 +80,27 @@ class Server(IRCObject.Object):
|
|||
self.socket = context.wrap_socket(self.socket)
|
||||
|
||||
def connect(self):
|
||||
family = socket.AF_INET if self.ipv4 else socket.AF_INET6
|
||||
ipv4 = self.connection_params.ipv4
|
||||
family = socket.AF_INET if ipv4 else socket.AF_INET6
|
||||
self.socket = socket.socket(family, socket.SOCK_STREAM)
|
||||
|
||||
self.socket.settimeout(5.0)
|
||||
if self.bindhost:
|
||||
|
||||
if self.connection_params.bindhost:
|
||||
self.socket.bind((self.bindhost, 0))
|
||||
if self.tls:
|
||||
if self.connection_params.tls:
|
||||
self.tls_wrap()
|
||||
|
||||
self.socket.connect((self.target_hostname, self.port))
|
||||
self.socket.connect((self.connection_params.hostname,
|
||||
self.connection_params.port))
|
||||
self.send_capibility_ls()
|
||||
|
||||
if self.password:
|
||||
self.send_pass(self.password)
|
||||
self.send_pass(self.connection_params.password)
|
||||
|
||||
self.send_user(self.original_username, self.original_realname)
|
||||
self.send_nick(self.original_nickname)
|
||||
self.send_user(self.connection_params.username,
|
||||
self.connection_params.realname)
|
||||
self.send_nick(self.connection_params.nickname)
|
||||
self.connected = True
|
||||
def disconnect(self):
|
||||
self.cached_fileno = self.socket.fileno()
|
||||
|
|
|
@ -142,3 +142,19 @@ def strip_font(s: str) -> str:
|
|||
s = s.replace(FONT_COLOR, "")
|
||||
return s
|
||||
|
||||
OPT_STR = typing.Optional[str]
|
||||
class IRCConnectionParameters(object):
|
||||
def __init__(self, id: int, alias: OPT_STR, hostname: str, port: int,
|
||||
tls: bool, ipv4: bool, password: OPT_STR, bindhost: OPT_STR,
|
||||
nickname: str, username: OPT_STR, realname: OPT_STR):
|
||||
self.id = id
|
||||
self.alias = alias
|
||||
self.hostname = hostname
|
||||
self.port = port
|
||||
self.tls = tls
|
||||
self.ipv4 = ipv4
|
||||
self.bindhost = bindhost
|
||||
self.password = password
|
||||
self.nickname = nickname
|
||||
self.username = username
|
||||
self.realname = realname
|
||||
|
|
Loading…
Add table
Reference in a new issue