Add Database.servers.get_by_alias, move IRCBot.get_server to

IRCBot.get_server_by_id, add IRCBot.get_server_by_alias and change
!connect/!disconnect to take aliases instead of IDs
This commit is contained in:
jesopo 2019-04-24 12:07:30 +01:00
parent bd4fdfdc7b
commit dc102f258d
7 changed files with 33 additions and 24 deletions

View file

@ -52,15 +52,12 @@ class Module(ModuleManager.BaseModule):
:usage: <server id> :usage: <server id>
:permission: connect :permission: connect
""" """
id = event["args_split"][0] alias = event["args"]
if not id.isdigit(): id = self.bot.database.servers.get_by_alias(alias)
raise utils.EventError("Please provide a numeric server ID") if id == None:
raise utils.EventError("Unknown server alias")
id = int(id) existing_server = self.bot.get_server_by_id(id)
if not self.bot.database.servers.get(id):
raise utils.EventError("Unknown server ID")
existing_server = self.bot.get_server(id)
if existing_server: if existing_server:
raise utils.EventError("Already connected to %s" % str( raise utils.EventError("Already connected to %s" % str(
existing_server)) existing_server))
@ -77,14 +74,13 @@ class Module(ModuleManager.BaseModule):
""" """
id = event["server"].id id = event["server"].id
if event["args"]: if event["args"]:
id = event["args_split"][0] print("alias")
if not id.isdigit(): alias = event["args"]
raise utils.EventError("Please provide a numeric server ID") id = self.bot.database.servers.get_by_alias(alias)
if id == None:
id = int(id) raise utils.EventError("Unknown server alias")
if not self.bot.database.servers.get(id): print(id)
raise utils.EventError("Unknown server ID") server = self.bot.get_server_by_id(id)
server = self.bot.get_server(id)
server.disconnect() server.disconnect()
self.bot.disconnect(server) self.bot.disconnect(server)

View file

@ -21,7 +21,7 @@ class Module(ModuleManager.BaseModule):
@utils.hook("timer.unban") @utils.hook("timer.unban")
def _timer_unban(self, event): def _timer_unban(self, event):
server = self.bot.get_server(event["server_id"]) server = self.bot.get_server_by_id(event["server_id"])
if event["channel_name"] in server.channels: if event["channel_name"] in server.channels:
channel = server.channels.get(event["channel_name"]) channel = server.channels.get(event["channel_name"])
channel.send_unban(event["hostmask"]) channel.send_unban(event["hostmask"])

View file

@ -340,7 +340,7 @@ class Module(ModuleManager.BaseModule):
if found_hook: if found_hook:
repo_hooked = True repo_hooked = True
server = self.bot.get_server(server_id) server = self.bot.get_server_by_id(server_id)
if server and channel_name in server.channels: if server and channel_name in server.channels:
if (branch and if (branch and
found_hook["branches"] and found_hook["branches"] and

View file

@ -31,7 +31,7 @@ class Module(ModuleManager.BaseModule):
@utils.hook("timer.in") @utils.hook("timer.in")
def timer_due(self, event): def timer_due(self, event):
server = self.bot.get_server(event["server_id"]) server = self.bot.get_server_by_id(event["server_id"])
if server: if server:
message = "%s: this is your reminder: %s" % ( message = "%s: this is your reminder: %s" % (
event["nickname"], event["message"]) event["nickname"], event["message"])

View file

@ -73,7 +73,7 @@ class Module(ModuleManager.BaseModule):
return None return None
server_id = int(server_id) server_id = int(server_id)
server = self.bot.get_server(server_id) server = self.bot.get_server_by_id(server_id)
if not server: if not server:
return None return None
return self._server_stats(server) return self._server_stats(server)
@ -99,7 +99,7 @@ class Module(ModuleManager.BaseModule):
return None return None
server_id = int(server_id) server_id = int(server_id)
server = self.bot.get_server(server_id) server = self.bot.get_server_by_id(server_id)
if not server: if not server:
return None return None
channels = {} channels = {}

View file

@ -24,7 +24,7 @@ class Servers(Table):
def get_all(self): def get_all(self):
return self.database.execute_fetchall( return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers") "SELECT server_id, alias FROM servers")
def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str, def get(self, id: int)-> typing.Tuple[int, typing.Optional[str], str,
int, typing.Optional[str], bool, bool, typing.Optional[str], str, int, typing.Optional[str], bool, bool, typing.Optional[str], str,
typing.Optional[str], typing.Optional[str]]: typing.Optional[str], typing.Optional[str]]:
return self.database.execute_fetchone( return self.database.execute_fetchone(
@ -32,6 +32,13 @@ class Servers(Table):
ipv4, bindhost, nickname, username, realname FROM servers WHERE ipv4, bindhost, nickname, username, realname FROM servers WHERE
server_id=?""", server_id=?""",
[id]) [id])
def get_by_alias(self, alias: str) -> typing.Optional[int]:
value = self.database.execute_fetchone(
"SELECT server_id FROM servers WHERE alias=? COLLATE NOCASE",
[alias])
if value:
return value[0]
return value
class Channels(Table): class Channels(Table):
def add(self, server_id: int, name: str): def add(self, server_id: int, name: str):

View file

@ -84,11 +84,17 @@ class Bot(object):
del self.other_sockets[sock.fileno()] del self.other_sockets[sock.fileno()]
self.poll.unregister(sock.fileno()) self.poll.unregister(sock.fileno())
def get_server(self, id: int) -> typing.Optional[IRCServer.Server]: def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values(): for server in self.servers.values():
if server.id == id: if server.id == id:
return server return server
return None return None
def get_server_by_alias(self, alias: str) -> typing.Optional[IRCServer.Server]:
alias_lower = alias.lower()
for server in self.servers.values():
if server.alias.lower() == alias_lower:
return server
return None
def connect(self, server: IRCServer.Server) -> bool: def connect(self, server: IRCServer.Server) -> bool:
try: try:
@ -235,7 +241,7 @@ class Bot(object):
self._events.on("server.disconnect").call(server=server) self._events.on("server.disconnect").call(server=server)
self.disconnect(server) self.disconnect(server)
if not self.get_server(server.id): if not self.get_server_by_id(server.id):
reconnect_delay = self.config.get("reconnect-delay", 10) reconnect_delay = self.config.get("reconnect-delay", 10)
self._timers.add("reconnect", reconnect_delay, self._timers.add("reconnect", reconnect_delay,
server_id=server.id) server_id=server.id)