From dc102f258d75493b11735fdd69e25f92be186c1f Mon Sep 17 00:00:00 2001 From: jesopo Date: Wed, 24 Apr 2019 12:07:30 +0100 Subject: [PATCH] Add Database.servers.get_by_alias, move IRCBot.get_server to IRCBot.get_server_by_id, add IRCBot.get_server_by_alias and change !connect/!disconnect to take aliases instead of IDs --- modules/admin.py | 28 ++++++++++++---------------- modules/channel_op.py | 2 +- modules/github/__init__.py | 2 +- modules/in.py | 2 +- modules/stats.py | 4 ++-- src/Database.py | 9 ++++++++- src/IRCBot.py | 10 ++++++++-- 7 files changed, 33 insertions(+), 24 deletions(-) diff --git a/modules/admin.py b/modules/admin.py index c9096e54..8d72a12b 100644 --- a/modules/admin.py +++ b/modules/admin.py @@ -52,15 +52,12 @@ class Module(ModuleManager.BaseModule): :usage: :permission: connect """ - id = event["args_split"][0] - if not id.isdigit(): - raise utils.EventError("Please provide a numeric server ID") + alias = event["args"] + id = self.bot.database.servers.get_by_alias(alias) + if id == None: + raise utils.EventError("Unknown server alias") - id = int(id) - if not self.bot.database.servers.get(id): - raise utils.EventError("Unknown server ID") - - existing_server = self.bot.get_server(id) + existing_server = self.bot.get_server_by_id(id) if existing_server: raise utils.EventError("Already connected to %s" % str( existing_server)) @@ -77,14 +74,13 @@ class Module(ModuleManager.BaseModule): """ id = event["server"].id if event["args"]: - id = event["args_split"][0] - if not id.isdigit(): - raise utils.EventError("Please provide a numeric server ID") - - id = int(id) - if not self.bot.database.servers.get(id): - raise utils.EventError("Unknown server ID") - server = self.bot.get_server(id) + print("alias") + alias = event["args"] + id = self.bot.database.servers.get_by_alias(alias) + if id == None: + raise utils.EventError("Unknown server alias") + print(id) + server = self.bot.get_server_by_id(id) server.disconnect() self.bot.disconnect(server) diff --git a/modules/channel_op.py b/modules/channel_op.py index 066cee4c..2f0a0620 100644 --- a/modules/channel_op.py +++ b/modules/channel_op.py @@ -21,7 +21,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("timer.unban") def _timer_unban(self, event): - server = self.bot.get_server(event["server_id"]) + server = self.bot.get_server_by_id(event["server_id"]) if event["channel_name"] in server.channels: channel = server.channels.get(event["channel_name"]) channel.send_unban(event["hostmask"]) diff --git a/modules/github/__init__.py b/modules/github/__init__.py index a143ff7c..ff85b8f8 100644 --- a/modules/github/__init__.py +++ b/modules/github/__init__.py @@ -340,7 +340,7 @@ class Module(ModuleManager.BaseModule): if found_hook: repo_hooked = True - server = self.bot.get_server(server_id) + server = self.bot.get_server_by_id(server_id) if server and channel_name in server.channels: if (branch and found_hook["branches"] and diff --git a/modules/in.py b/modules/in.py index 890dc680..f1eda122 100644 --- a/modules/in.py +++ b/modules/in.py @@ -31,7 +31,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("timer.in") def timer_due(self, event): - server = self.bot.get_server(event["server_id"]) + server = self.bot.get_server_by_id(event["server_id"]) if server: message = "%s: this is your reminder: %s" % ( event["nickname"], event["message"]) diff --git a/modules/stats.py b/modules/stats.py index 742574f4..6b19f061 100644 --- a/modules/stats.py +++ b/modules/stats.py @@ -73,7 +73,7 @@ class Module(ModuleManager.BaseModule): return None server_id = int(server_id) - server = self.bot.get_server(server_id) + server = self.bot.get_server_by_id(server_id) if not server: return None return self._server_stats(server) @@ -99,7 +99,7 @@ class Module(ModuleManager.BaseModule): return None server_id = int(server_id) - server = self.bot.get_server(server_id) + server = self.bot.get_server_by_id(server_id) if not server: return None channels = {} diff --git a/src/Database.py b/src/Database.py index f1563ac0..67fe4a6f 100644 --- a/src/Database.py +++ b/src/Database.py @@ -24,7 +24,7 @@ class Servers(Table): def get_all(self): return self.database.execute_fetchall( "SELECT server_id, alias FROM servers") - def get(self, id: int) -> typing.Tuple[int, typing.Optional[str], str, + def get(self, id: int)-> typing.Tuple[int, typing.Optional[str], str, int, typing.Optional[str], bool, bool, typing.Optional[str], str, typing.Optional[str], typing.Optional[str]]: return self.database.execute_fetchone( @@ -32,6 +32,13 @@ class Servers(Table): ipv4, bindhost, nickname, username, realname FROM servers WHERE server_id=?""", [id]) + def get_by_alias(self, alias: str) -> typing.Optional[int]: + value = self.database.execute_fetchone( + "SELECT server_id FROM servers WHERE alias=? COLLATE NOCASE", + [alias]) + if value: + return value[0] + return value class Channels(Table): def add(self, server_id: int, name: str): diff --git a/src/IRCBot.py b/src/IRCBot.py index d5beb766..0abfb432 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -84,11 +84,17 @@ class Bot(object): del self.other_sockets[sock.fileno()] self.poll.unregister(sock.fileno()) - def get_server(self, id: int) -> typing.Optional[IRCServer.Server]: + def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]: for server in self.servers.values(): if server.id == id: return server return None + def get_server_by_alias(self, alias: str) -> typing.Optional[IRCServer.Server]: + alias_lower = alias.lower() + for server in self.servers.values(): + if server.alias.lower() == alias_lower: + return server + return None def connect(self, server: IRCServer.Server) -> bool: try: @@ -235,7 +241,7 @@ class Bot(object): self._events.on("server.disconnect").call(server=server) self.disconnect(server) - if not self.get_server(server.id): + if not self.get_server_by_id(server.id): reconnect_delay = self.config.get("reconnect-delay", 10) self._timers.add("reconnect", reconnect_delay, server_id=server.id)