From e07553c3627b80f20cdc81a35030bf0540924db8 Mon Sep 17 00:00:00 2001 From: jesopo Date: Tue, 30 Oct 2018 14:58:48 +0000 Subject: [PATCH] Add type/return hints throughout src/ and, in doing so, fix some cyclical references. --- modules/karma.py | 2 +- modules/line_handler.py | 3 +- modules/sed.py | 13 ++- src/Cache.py | 18 ++-- src/Config.py | 10 +- src/Database.py | 101 +++++++++++--------- src/EventManager.py | 207 +++++++++++++++++++++------------------- src/Exports.py | 34 +++---- src/IRCBot.py | 46 +++++---- src/IRCBuffer.py | 29 +++--- src/IRCChannel.py | 83 +++++++++------- src/IRCServer.py | 137 +++++++++++++------------- src/IRCUser.py | 44 +++++---- src/Logging.py | 18 ++-- src/ModuleManager.py | 49 ++++++---- src/Socket.py | 26 ++--- src/Timers.py | 62 ++++++------ src/utils/__init__.py | 123 +++++++----------------- src/utils/consts.py | 2 + src/utils/http.py | 11 ++- src/utils/irc.py | 46 +++++---- src/utils/parse.py | 57 +++++++++++ 22 files changed, 605 insertions(+), 516 deletions(-) create mode 100644 src/utils/consts.py create mode 100644 src/utils/parse.py diff --git a/modules/karma.py b/modules/karma.py index 441b061e..d3bab796 100644 --- a/modules/karma.py +++ b/modules/karma.py @@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule): if not event["user"].last_karma or (time.time()-event["user" ].last_karma) >= KARMA_DELAY_SECONDS: target = match.group(1).strip() - if utils.irc.lower(event["server"], target + if utils.irc.lower(event["server"].case_mapping, target ) == event["user"].name: if verbose: self.events.on("send.stderr").call( diff --git a/modules/line_handler.py b/modules/line_handler.py index da0d5d07..3ee8304b 100644 --- a/modules/line_handler.py +++ b/modules/line_handler.py @@ -542,7 +542,8 @@ class Module(ModuleManager.BaseModule): # we need a registered nickname for this channel @utils.hook("raw.477", default_event=True) def handle_477(self, event): - channel_name = utils.irc.lower(event["server"], event["args"][1]) + channel_name = utils.irc.lower(event["server"].case_mapping, + event["args"][1]) if channel_name in event["server"]: key = event["server"].attempted_join[channel_name] self.timers.add("rejoin", 5, channel_name=channe_name, key=key, diff --git a/modules/sed.py b/modules/sed.py index 6de7b4af..7b98ab70 100644 --- a/modules/sed.py +++ b/modules/sed.py @@ -11,12 +11,16 @@ REGEX_SED = re.compile("^s/") "help": "Disable/Enable sed only looking at the messages sent by the user", "validate": utils.bool_or_none}) class Module(ModuleManager.BaseModule): + def _closest_setting(self, event, setting, default): + return event["channel"].get_setting(setting, + event["server"].get_setting(setting, default)) + @utils.hook("received.message.channel") def channel_message(self, event): sed_split = re.split(REGEX_SPLIT, event["message"], 3) if event["message"].startswith("s/") and len(sed_split) > 2: - if event["action"] or not utils.get_closest_setting( - event, "sed", False): + if event["action"] or not self._closest_setting(event, "sed", + False): return regex_flags = 0 @@ -48,9 +52,8 @@ class Module(ModuleManager.BaseModule): return replace = sed_split[2].replace("\\/", "/") - for_user = event["user"].nickname if utils.get_closest_setting( - event, "sed-sender-only", False - ) else None + for_user = event["user"].nickname if self._closest_setting(event, + "sed-sender-only", False) else None line = event["channel"].buffer.find(pattern, from_self=False, for_user=for_user, not_pattern=REGEX_SED) if line: diff --git a/src/Cache.py b/src/Cache.py index 2de55afa..46b39bda 100644 --- a/src/Cache.py +++ b/src/Cache.py @@ -1,21 +1,21 @@ -import time, uuid +import time, typing, uuid class Cache(object): def __init__(self): self._items = {} self._item_to_id = {} - def cache(self, item): + def cache(self, item: typing.Any) -> str: return self._cache(item, None) - def temporary_cache(self, item, timeout): + def temporary_cache(self, item: typing.Any, timeout: float)-> str: return self._cache(item, timeout) - def _cache(self, item, timeout): + def _cache(self, item: typing.Any, timeout: float) -> str: id = str(uuid.uuid4()) self._items[id] = [item, time.monotonic()+timeout] self._item_to_id[item] = id return id - def next_expiration(self): + def next_expiration(self) -> float: expirations = [self._items[id][1] for id in self._items] expirations = list(filter(None, expirations)) if not expirations: @@ -35,17 +35,17 @@ class Cache(object): del self._items[id] del self._item_to_id[item] - def has_item(self, item): + def has_item(self, item: typing.Any) -> bool: return item in self._item_to_id - def get(self, id): + def get(self, id: str) -> typing.Any: item, expiration = self._items[id] return item - def get_expiration(self, item): + def get_expiration(self, item: typing.Any) -> float: id = self._item_to_id[item] item, expiration = self._items[id] return expiration - def until_expiration(self, item): + def until_expiration(self, item: typing.Any) -> float: expiration = self.get_expiration(item) return expiration-time.monotonic() diff --git a/src/Config.py b/src/Config.py index 611b5b7b..dacb14dd 100644 --- a/src/Config.py +++ b/src/Config.py @@ -1,7 +1,7 @@ -import configparser, os +import configparser, os, typing class Config(object): - def __init__(self, location): + def __init__(self, location: str): self.location = location self._config = {} self.load() @@ -13,10 +13,10 @@ class Config(object): parser.read_string(config_file.read()) self._config = dict(parser["bot"].items()) - def __getitem__(self, key): + def __getitem__(self, key: str) -> typing.Any: return self._config[key] - def get(self, key, default=None): + def get(self, key: str, default: typing.Any=None) -> typing.Any: return self._config.get(key, default) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._config diff --git a/src/Database.py b/src/Database.py index c1d07869..c3d48cb6 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,12 +1,14 @@ -import json, os, sqlite3, threading, time +import json, os, sqlite3, threading, time, typing +from src import Logging class Table(object): def __init__(self, database): self.database = database class Servers(Table): - def add(self, alias, hostname, port, password, ipv4, tls, bindhost, - nickname, username=None, realname=None): + def add(self, alias: str, hostname: str, port: int, password: str, + ipv4: bool, tls: bool, bindhost: str, + nickname: str, username: str=None, realname: str=None): username = username or nickname realname = realname or nickname self.database.execute( @@ -18,7 +20,7 @@ class Servers(Table): def get_all(self): return self.database.execute_fetchall( "SELECT server_id, alias FROM servers") - def get(self, id): + def get(self, id: int): return self.database.execute_fetchone( """SELECT server_id, alias, hostname, port, password, ipv4, tls, bindhost, nickname, username, realname FROM servers WHERE @@ -26,46 +28,46 @@ class Servers(Table): [id]) class Channels(Table): - def add(self, server_id, name): + def add(self, server_id: int, name: str): self.database.execute("""INSERT OR IGNORE INTO channels (server_id, name) VALUES (?, ?)""", [server_id, name.lower()]) - def delete(self, channel_id): + def delete(self, channel_id: int): self.database.execute("DELETE FROM channels WHERE channel_id=?", [channel_id]) - def get_id(self, server_id, name): + def get_id(self, server_id: int, name: str): value = self.database.execute_fetchone("""SELECT channel_id FROM channels WHERE server_id=? AND name=?""", [server_id, name.lower()]) return value if value == None else value[0] class Users(Table): - def add(self, server_id, nickname): + def add(self, server_id: int, nickname: str): self.database.execute("""INSERT OR IGNORE INTO users (server_id, nickname) VALUES (?, ?)""", [server_id, nickname.lower()]) - def delete(self, user_id): + def delete(self, user_id: int): self.database.execute("DELETE FROM users WHERE user_id=?", [user_id]) - def get_id(self, server_id, nickname): + def get_id(self, server_id: int, nickname: str): value = self.database.execute_fetchone("""SELECT user_id FROM users WHERE server_id=? and nickname=?""", [server_id, nickname.lower()]) return value if value == None else value[0] class BotSettings(Table): - def set(self, setting, value): + def set(self, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO bot_settings VALUES (?, ?)", [setting.lower(), json.dumps(value)]) - def get(self, setting, default=None): + def get(self, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( "SELECT value FROM bot_settings WHERE setting=?", [setting.lower()]) if value: return json.loads(value[0]) return default - def find(self, pattern, default=[]): + def find(self, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( "SELECT setting, value FROM bot_settings WHERE setting LIKE ?", [pattern.lower()]) @@ -74,19 +76,19 @@ class BotSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, prefix, default=[]): + def find_prefix(self, prefix: str, default: typing.Any=[]): return self.find("%s%%" % prefix, default) - def delete(self, setting): + def delete(self, setting: str): self.database.execute( "DELETE FROM bot_settings WHERE setting=?", [setting.lower()]) class ServerSettings(Table): - def set(self, server_id, setting, value): + def set(self, server_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)", [server_id, setting.lower(), json.dumps(value)]) - def get(self, server_id, setting, default=None): + def get(self, server_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM server_settings WHERE server_id=? AND setting=?""", @@ -94,7 +96,7 @@ class ServerSettings(Table): if value: return json.loads(value[0]) return default - def find(self, server_id, pattern, default=[]): + def find(self, server_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM server_settings WHERE server_id=? AND setting LIKE ?""", @@ -104,26 +106,26 @@ class ServerSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, server_id, prefix, default=[]): + def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]): return self.find_server_settings(server_id, "%s%%" % prefix, default) - def delete(self, server_id, setting): + def delete(self, server_id: int, setting: str): self.database.execute( "DELETE FROM server_settings WHERE server_id=? AND setting=?", [server_id, setting.lower()]) class ChannelSettings(Table): - def set(self, channel_id, setting, value): + def set(self, channel_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)", [channel_id, setting.lower(), json.dumps(value)]) - def get(self, channel_id, setting, default=None): + def get(self, channel_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM channel_settings WHERE channel_id=? AND setting=?""", [channel_id, setting.lower()]) if value: return json.loads(value[0]) return default - def find(self, channel_id, pattern, default=[]): + def find(self, channel_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM channel_settings WHERE channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()]) @@ -132,15 +134,15 @@ class ChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, channel_id, prefix, default=[]): + def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]): return self.find_channel_settings(channel_id, "%s%%" % prefix, default) - def delete(self, channel_id, setting): + def delete(self, channel_id: int, setting: str): self.database.execute( """DELETE FROM channel_settings WHERE channel_id=? AND setting=?""", [channel_id, setting.lower()]) - def find_by_setting(self, setting, default=[]): + def find_by_setting(self, setting: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.server_id, channels.name, channel_settings.value FROM channel_settings @@ -154,18 +156,19 @@ class ChannelSettings(Table): return default class UserSettings(Table): - def set(self, user_id, setting, value): + def set(self, user_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)", [user_id, setting.lower(), json.dumps(value)]) - def get(self, user_id, setting, default=None): + def get(self, user_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM user_settings WHERE user_id=? and setting=?""", [user_id, setting.lower()]) if value: return json.loads(value[0]) return default - def find_all_by_setting(self, server_id, setting, default=[]): + def find_all_by_setting(self, server_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT users.nickname, user_settings.value FROM user_settings INNER JOIN users ON @@ -177,7 +180,7 @@ class UserSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find(self, user_id, pattern, default=[]): + def find(self, user_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute( """SELECT setting, value FROM user_settings WHERE user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()]) @@ -186,20 +189,22 @@ class UserSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, user_id, prefix, default=[]): + def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]): return self.find_user_settings(user_id, "%s%%" % prefix, default) - def delete(self, user_id, setting): + def delete(self, user_id: int, setting: str): self.database.execute( """DELETE FROM user_settings WHERE user_id=? AND setting=?""", [user_id, setting.lower()]) class UserChannelSettings(Table): - def set(self, user_id, channel_id, setting, value): + def set(self, user_id: int, channel_id: int, setting: str, + value: typing.Any): self.database.execute( """INSERT OR REPLACE INTO user_channel_settings VALUES (?, ?, ?, ?)""", [user_id, channel_id, setting.lower(), json.dumps(value)]) - def get(self, user_id, channel_id, setting, default=None): + def get(self, user_id: int, channel_id: int, setting: str, + default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting=?""", @@ -207,7 +212,8 @@ class UserChannelSettings(Table): if value: return json.loads(value[0]) return default - def find(self, user_id, channel_id, pattern, default=[]): + def find(self, user_id: int, channel_id: int, pattern: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting LIKE '?'""", @@ -217,10 +223,12 @@ class UserChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, user_id, channel_id, prefix, default=[]): + def find_prefix(self, user_id: int, channel_id: int, prefix: str, + default: typing.Any=[]): return self.find_user_settings(user_id, channel_id, "%s%%" % prefix, default) - def find_by_setting(self, user_id, setting, default=[]): + def find_by_setting(self, user_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.name, user_channel_settings.value FROM user_channel_settings INNER JOIN channels ON @@ -232,7 +240,8 @@ class UserChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_all_by_setting(self, server_id, setting, default=[]): + def find_all_by_setting(self, server_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.name, users.nickname, user_channel_settings.value FROM @@ -246,14 +255,14 @@ class UserChannelSettings(Table): values[i] = value[0], value[1], json.loads(value[2]) return values return default - def delete(self, user_id, channel_id, setting): + def delete(self, user_id: int, channel_id: int, setting: str): self.database.execute( """DELETE FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting=?""", [user_id, channel_id, setting.lower()]) class Database(object): - def __init__(self, log, location): + def __init__(self, log: "Logging.Log", location: str): self.log = log self.location = location self.database = sqlite3.connect(self.location, @@ -284,7 +293,9 @@ class Database(object): self._cursor = self.database.cursor() return self._cursor - def _execute_fetch(self, query, fetch_func, params=[]): + def _execute_fetch(self, query: str, + fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any], + params: typing.List=[]): printable_query = " ".join(query.split()) self.log.trace("executing query: \"%s\" (params: %s)", [printable_query, params]) @@ -299,16 +310,16 @@ class Database(object): self.log.trace("executed in %fms", [total_milliseconds]) return value - def execute_fetchall(self, query, params=[]): + def execute_fetchall(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: cursor.fetchall(), params) - def execute_fetchone(self, query, params=[]): + def execute_fetchone(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: cursor.fetchone(), params) - def execute(self, query, params=[]): + def execute(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: None, params) - def has_table(self, table_name): + def has_table(self, table_name: str): result = self.execute_fetchone("""SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?""", [table_name]) diff --git a/src/EventManager.py b/src/EventManager.py index 15115ad8..bdaa3b7a 100644 --- a/src/EventManager.py +++ b/src/EventManager.py @@ -1,5 +1,5 @@ -import itertools, time, traceback -from src import utils +import itertools, time, traceback, typing +from src import Logging, utils PRIORITY_URGENT = 0 PRIORITY_HIGH = 1 @@ -11,94 +11,39 @@ DEFAULT_PRIORITY = PRIORITY_MEDIUM DEFAULT_EVENT_DELIMITER = "." DEFAULT_MULTI_DELIMITER = "|" +CALLBACK_TYPE = typing.Callable[["Event"], typing.Any] + class Event(object): - def __init__(self, name, **kwargs): + def __init__(self, name: str, **kwargs): self.name = name self.kwargs = kwargs self.eaten = False - def __getitem__(self, key): + def __getitem__(self, key: str) -> typing.Any: return self.kwargs[key] - def get(self, key, default=None): + def get(self, key: str, default=None) -> typing.Any: return self.kwargs.get(key, default) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self.kwargs def eat(self): self.eaten = True class EventCallback(object): - def __init__(self, function, priority, kwargs): + def __init__(self, function: CALLBACK_TYPE, priority: int, kwargs: dict): self.function = function self.priority = priority self.kwargs = kwargs - self.docstring = utils.parse_docstring(function.__doc__) + self.docstring = utils.parse.docstring(function.__doc__) - def call(self, event): + def call(self, event: Event) -> typing.Any: return self.function(event) - def get_kwarg(self, name, default=None): + def get_kwarg(self, name: str, default=None) -> typing.Any: item = self.kwargs.get(name, default) return item or self.docstring.items.get(name, default) -class MultipleEventHook(object): - def __init__(self): - self._event_hooks = set([]) - def _add(self, event_hook): - self._event_hooks.add(event_hook) - - def hook(self, function, **kwargs): - for event_hook in self._event_hooks: - event_hook.hook(function, **kwargs) - - def call_limited(self, maximum, **kwargs): - returns = [] - for event_hook in self._event_hooks: - returns.append(event_hook.call_limited(maximum, **kwargs)) - return returns - def call(self, **kwargs): - returns = [] - for event_hook in self._event_hooks: - returns.append(event_hook.call(**kwargs)) - return returns - -class EventHookContext(object): - def __init__(self, parent, context): - self._parent = parent - self.context = context - def hook(self, function, priority=DEFAULT_PRIORITY, replay=False, - **kwargs): - return self._parent._context_hook(self.context, function, priority, - replay, kwargs) - def unhook(self, callback): - self._parent.unhook(callback) - - def on(self, subevent, *extra_subevents, - delimiter=DEFAULT_EVENT_DELIMITER): - return self._parent._context_on(self.context, subevent, - extra_subevents, delimiter) - - def call_for_result(self, default=None, **kwargs): - return self._parent.call_for_result(default, **kwargs) - def assure_call(self, **kwargs): - self._parent.assure_call(**kwargs) - def call(self, **kwargs): - return self._parent.call(**kwargs) - def call_limited(self, maximum, **kwargs): - return self._parent.call_limited(maximum, **kwargs) - - def call_unsafe_for_result(self, default=None, **kwargs): - return self._parent.call_unsafe_for_result(default, **kwargs) - def call_unsafe(self, **kwargs): - return self._parent.call_unsafe(**kwargs) - def call_unsafe_limited(self, maximum, **kwargs): - return self._parent.call_unsafe_limited(maximum, **kwargs) - - def get_hooks(self): - return self._parent.get_hooks() - def get_children(self): - return self._parent.get_children() - class EventHook(object): - def __init__(self, log, name=None, parent=None): + def __init__(self, log: Logging.Log, name: str = None, + parent: "EventHook" = None): self.log = log self.name = name self.parent = parent @@ -107,10 +52,10 @@ class EventHook(object): self._stored_events = [] self._context_hooks = {} - def _make_event(self, kwargs): + def _make_event(self, kwargs: dict) -> Event: return Event(self._get_path(), **kwargs) - def _get_path(self): + def _get_path(self) -> str: path = [] parent = self while not parent == None and not parent.name == None: @@ -118,15 +63,17 @@ class EventHook(object): parent = parent.parent return DEFAULT_EVENT_DELIMITER.join(path[::-1]) - def new_context(self, context): + def new_context(self, context: str) -> "EventHookContext": return EventHookContext(self, context) - def hook(self, function, priority=DEFAULT_PRIORITY, replay=False, - **kwargs): + def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY, + replay: bool = False, **kwargs) -> EventCallback: return self._hook(function, None, priority, replay, kwargs) - def _context_hook(self, context, function, priority, replay, kwargs): + def _context_hook(self, context: str, function: CALLBACK_TYPE, + priority: int, replay: bool, kwargs: dict) -> EventCallback: return self._hook(function, context, priority, replay, kwargs) - def _hook(self, function, context, priority, replay, kwargs): + def _hook(self, function: CALLBACK_TYPE, context: str, priority: int, + replay: bool, kwargs: dict) -> EventCallback: callback = EventCallback(function, priority, kwargs) if context == None: @@ -142,7 +89,7 @@ class EventHook(object): self._stored_events = None return callback - def unhook(self, callback): + def unhook(self, callback: "EventHook"): if callback in self._hooks: self._hooks.remove(callback) @@ -155,7 +102,8 @@ class EventHook(object): for context in empty: del self._context_hooks[context] - def _make_multiple_hook(self, source, context, events): + def _make_multiple_hook(self, source: "EventHook", context: str, + events: typing.List[str]) -> "MultipleEventHook": multiple_event_hook = MultipleEventHook() for event in events: event_hook = source.get_child(event) @@ -164,13 +112,15 @@ class EventHook(object): multiple_event_hook._add(event_hook) return multiple_event_hook - def on(self, subevent, *extra_subevents, - delimiter=DEFAULT_EVENT_DELIMITER): + def on(self, subevent: str, *extra_subevents, + delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook": return self._on(subevent, extra_subevents, None, delimiter) - def _context_on(self, context, subevent, extra_subevents, - delimiter=DEFAULT_EVENT_DELIMITER): + def _context_on(self, context: str, subevent: str, + extra_subevents: typing.List[str], + delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook": return self._on(subevent, extra_subevents, context, delimiter) - def _on(self, subevent, extra_subevents, context, delimiter): + def _on(self, subevent: str, extra_subevents: typing.List[str], + context: str, delimiter: str) -> "EventHook": if delimiter in subevent: event_chain = subevent.split(delimiter) event_obj = self @@ -193,26 +143,28 @@ class EventHook(object): child = child.new_context(context) return child - def call_for_result(self, default=None, **kwargs): + def call_for_result(self, default=None, **kwargs) -> typing.Any: return (self.call_limited(1, **kwargs) or [default])[0] def assure_call(self, **kwargs): if not self._stored_events == None: self._stored_events.append(kwargs) else: self._call(kwargs, True, None) - def call(self, **kwargs): + def call(self, **kwargs) -> typing.List[typing.Any]: return self._call(kwargs, True, None) - def call_limited(self, maximum, **kwargs): + def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]: return self._call(kwargs, True, None) - def call_unsafe_for_result(self, default=None, **kwargs): + def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any: return (self.call_unsafe_limited(1, **kwargs) or [default])[0] - def call_unsafe(self, **kwargs): + def call_unsafe(self, **kwargs) -> typing.List[typing.Any]: return self._call(kwargs, False, None) - def call_unsafe_limited(self, maximum, **kwargs): + def call_unsafe_limited(self, maximum: int, **kwargs + ) -> typing.List[typing.Any]: return self._call(kwargs, False, maximum) - def _call(self, kwargs, safe, maximum): + def _call(self, kwargs: dict, safe: bool, maximum: int + ) -> typing.List[typing.Any]: event_path = self._get_path() self.log.trace("calling event: \"%s\" (params: %s)", [event_path, kwargs]) @@ -240,13 +192,13 @@ class EventHook(object): return returns - def get_child(self, child_name): + def get_child(self, child_name: str) -> "EventHook": child_name_lower = child_name.lower() if not child_name_lower in self._children: self._children[child_name_lower] = EventHook(self.log, child_name_lower, self) return self._children[child_name_lower] - def remove_child(self, child_name): + def remove_child(self, child_name: str): child_name_lower = child_name.lower() if child_name_lower in self._children: del self._children[child_name_lower] @@ -256,11 +208,11 @@ class EventHook(object): self.parent.remove_child(self.name) self.parent.check_purge() - def remove_context(self, context): + def remove_context(self, context: str): del self._context_hooks[context] - def has_context(self, context): + def has_context(self, context: str) -> bool: return context in self._context_hooks - def purge_context(self, context): + def purge_context(self, context: str): if self.has_context(context): self.remove_context(context) @@ -268,10 +220,69 @@ class EventHook(object): child = self.get_child(child_name) child.purge_context(context) - def get_hooks(self): + def get_hooks(self) -> typing.List[EventCallback]: return sorted(self._hooks + sum(self._context_hooks.values(), []), key=lambda e: e.priority) - def get_children(self): + def get_children(self) -> typing.List["EventHook"]: return list(self._children.keys()) - def is_empty(self): + def is_empty(self) -> bool: return len(self.get_hooks() + self.get_children()) == 0 + +class MultipleEventHook(object): + def __init__(self): + self._event_hooks = set([]) + def _add(self, event_hook: EventHook): + self._event_hooks.add(event_hook) + + def hook(self, function: CALLBACK_TYPE, **kwargs): + for event_hook in self._event_hooks: + event_hook.hook(function, **kwargs) + + def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]: + returns = [] + for event_hook in self._event_hooks: + returns.append(event_hook.call_limited(maximum, **kwargs)) + return returns + def call(self, **kwargs) -> typing.List[typing.Any]: + returns = [] + for event_hook in self._event_hooks: + returns.append(event_hook.call(**kwargs)) + return returns + +class EventHookContext(object): + def __init__(self, parent, context): + self._parent = parent + self.context = context + def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY, + replay: bool = False, **kwargs) -> EventCallback: + return self._parent._context_hook(self.context, function, priority, + replay, kwargs) + def unhook(self, callback: EventCallback): + self._parent.unhook(callback) + + def on(self, subevent: str, *extra_subevents, + delimiter: str = DEFAULT_EVENT_DELIMITER) -> EventHook: + return self._parent._context_on(self.context, subevent, + extra_subevents, delimiter) + + def call_for_result(self, default=None, **kwargs) -> typing.Any: + return self._parent.call_for_result(default, **kwargs) + def assure_call(self, **kwargs): + self._parent.assure_call(**kwargs) + def call(self, **kwargs) -> typing.List[typing.Any]: + return self._parent.call(**kwargs) + def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]: + return self._parent.call_limited(maximum, **kwargs) + + def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any: + return self._parent.call_unsafe_for_result(default, **kwargs) + def call_unsafe(self, **kwargs) -> typing.List[typing.Any]: + return self._parent.call_unsafe(**kwargs) + def call_unsafe_limited(self, maximum: int, **kwargs + ) -> typing.List[typing.Any]: + return self._parent.call_unsafe_limited(maximum, **kwargs) + + def get_hooks(self) -> typing.List[EventCallback]: + return self._parent.get_hooks() + def get_children(self) -> typing.List[EventHook]: + return self._parent.get_children() diff --git a/src/Exports.py b/src/Exports.py index 8baca50d..68b25933 100644 --- a/src/Exports.py +++ b/src/Exports.py @@ -1,28 +1,18 @@ - - -class ExportsContext(object): - def __init__(self, parent, context): - self._parent = parent - self.context = context - - def add(self, setting, value): - self._parent._context_add(self.context, setting, value) - def get_all(self, setting): - return self._parent.get_all(setting) +import typing class Exports(object): def __init__(self): self._exports = {} self._context_exports = {} - def new_context(self, context): + def new_context(self, context: str) -> "ExportsContext": return ExportsContext(self, context) - def add(self, setting, value): + def add(self, setting: str, value: typing.Any): self._add(None, setting, value) - def _context_add(self, context, setting, value): + def _context_add(self, context: str, setting: str, value: typing.Any): self._add(context, setting, value) - def _add(self, context, setting, value): + def _add(self, context: str, setting: str, value: typing.Any): if context == None: if not setting in self_exports: self._exports[setting] = [] @@ -34,11 +24,21 @@ class Exports(object): self._context_exports[context][setting] = [] self._context_exports[context][setting].append(value) - def get_all(self, setting): + def get_all(self, setting: str) -> typing.List[typing.Any]: return self._exports.get(setting, []) + sum([ exports.get(setting, []) for exports in self._context_exports.values()], []) - def purge_context(self, context): + def purge_context(self, context: str): if context in self._context_exports: del self._context_exports[context] + +class ExportsContext(object): + def __init__(self, parent: Exports, context: str): + self._parent = parent + self.context = context + + def add(self, setting: str, value: typing.Any): + self._parent._context_add(self.context, setting, value) + def get_all(self, setting: str) -> typing.List[typing.Any]: + return self._parent.get_all(setting) diff --git a/src/IRCBot.py b/src/IRCBot.py index 657fb135..61f21114 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,4 +1,4 @@ -import os, select, socket, sys, threading, time, traceback, uuid +import os, select, socket, sys, threading, time, traceback, typing, uuid from src import EventManager, Exports, IRCServer, Logging, ModuleManager from src import Socket, utils @@ -28,14 +28,15 @@ class Bot(object): self._trigger_functions = [] - def trigger(self, func=None): + def trigger(self, func: typing.Callable[[], typing.Any]=None): self.lock.acquire() if func: self._trigger_functions.append(func) self._trigger_client.send(b"TRIGGER") self.lock.release() - def add_server(self, server_id, connect=True): + def add_server(self, server_id: int, connect: bool = True + ) -> typing.Optional[IRCServer.Server]: (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, username, realname) = self.database.servers.get(server_id) @@ -49,20 +50,20 @@ class Bot(object): self.connect(new_server) return new_server - def add_socket(self, sock): + def add_socket(self, sock: socket.socket): self.other_sockets[sock.fileno()] = sock self.poll.register(sock.fileno(), select.EPOLLIN) - def remove_socket(self, sock): + def remove_socket(self, sock: socket.socket): del self.other_sockets[sock.fileno()] self.poll.unregister(sock.fileno()) - def get_server(self, id): + def get_server(self, id: int) -> typing.Optional[IRCServer.Server]: for server in self.servers.values(): if server.id == id: return server - def connect(self, server): + def connect(self, server: IRCServer.Server) -> bool: try: server.connect() except: @@ -73,7 +74,7 @@ class Bot(object): self.poll.register(server.fileno(), select.EPOLLOUT) return True - def next_send(self): + def next_send(self) -> typing.Optional[float]: next = None for server in self.servers.values(): timeout = server.send_throttle_timeout() @@ -81,7 +82,7 @@ class Bot(object): next = timeout return next - def next_ping(self): + def next_ping(self) -> typing.Optional[float]: timeouts = [] for server in self.servers.values(): timeout = server.until_next_ping() @@ -90,7 +91,8 @@ class Bot(object): if not timeouts: return None return min(timeouts) - def next_read_timeout(self): + + def next_read_timeout(self) -> typing.Optional[float]: timeouts = [] for server in self.servers.values(): timeouts.append(server.until_read_timeout()) @@ -98,7 +100,7 @@ class Bot(object): return None return min(timeouts) - def get_poll_timeout(self): + def get_poll_timeout(self) -> float: timeouts = [] timeouts.append(self._timers.next()) timeouts.append(self.next_send()) @@ -107,15 +109,15 @@ class Bot(object): timeouts.append(self.cache.next_expiration()) return min([timeout for timeout in timeouts if not timeout == None]) - def register_read(self, server): + def register_read(self, server: IRCServer.Server): self.poll.modify(server.fileno(), select.EPOLLIN) - def register_write(self, server): + def register_write(self, server: IRCServer.Server): self.poll.modify(server.fileno(), select.EPOLLOUT) - def register_both(self, server): + def register_both(self, server: IRCServer.Server): self.poll.modify(server.fileno(), select.EPOLLIN|select.EPOLLOUT) - def disconnect(self, server): + def disconnect(self, server: IRCServer.Server): try: self.poll.unregister(server.fileno()) except FileNotFoundError: @@ -123,23 +125,25 @@ class Bot(object): del self.servers[server.fileno()] @utils.hook("timer.reconnect") - def reconnect(self, event): + def reconnect(self, event: EventManager.Event): server = self.add_server(event["server_id"], False) if self.connect(server): self.servers[server.fileno()] = server else: event["timer"].redo() - def set_setting(self, setting, value): + def set_setting(self, setting: str, value: typing.Any): self.database.bot_settings.set(setting, value) - def get_setting(self, setting, default=None): + def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any: return self.database.bot_settings.get(setting, default) - def find_settings(self, pattern, default=[]): + def find_settings(self, pattern: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.database.bot_settings.find(pattern, default) - def find_settings_prefix(self, prefix, default=[]): + def find_settings_prefix(self, prefix: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.database.bot_settings.find_prefix( prefix, default) - def del_setting(self, setting): + def del_setting(self, setting: str): self.database.bot_settings.delete(setting) def run(self): diff --git a/src/IRCBuffer.py b/src/IRCBuffer.py index 06465749..24fde7bf 100644 --- a/src/IRCBuffer.py +++ b/src/IRCBuffer.py @@ -1,8 +1,9 @@ -import re -from src import utils +import re, typing +from src import IRCBot, utils class BufferLine(object): - def __init__(self, sender, message, action, tags, from_self, method): + def __init__(self, sender: str, message: str, action: bool, tags: dict, + from_self: bool, method: str): self.sender = sender self.message = message self.action = action @@ -11,35 +12,39 @@ class BufferLine(object): self.method = method class Buffer(object): - def __init__(self, bot, server): + def __init__(self, bot: "IRCBot.Bot", server: "IRCServer.Server"): self.bot = bot self.server = server self.lines = [] self.max_lines = 64 self._skip_next = False - def _add_message(self, sender, message, action, tags, from_self, method): + def _add_message(self, sender: str, message: str, action: bool, tags: dict, + from_self: bool, method: str): if not self._skip_next: line = BufferLine(sender, message, action, tags, from_self, method) self.lines.insert(0, line) if len(self.lines) > self.max_lines: self.lines.pop() self._skip_next = False - def add_message(self, sender, message, action, tags, from_self=False): + def add_message(self, sender: str, message: str, action: bool, tags: dict, + from_self: bool=False): self._add_message(sender, message, action, tags, from_self, "PRIVMSG") - def add_notice(self, sender, message, tags, from_self=False): + def add_notice(self, sender: str, message: str, tags: dict, + from_self: bool=False): self._add_message(sender, message, False, tags, from_self, "NOTICE") - def get(self, index=0, **kwargs): + def get(self, index: int=0, **kwargs) -> typing.Optional[BufferLine]: from_self = kwargs.get("from_self", True) for line in self.lines: if line.from_self and not from_self: continue return line - def find(self, pattern, **kwargs): + def find(self, pattern: typing.Union[str, typing.Pattern[str]], **kwargs + ) -> typing.Optional[BufferLine]: from_self = kwargs.get("from_self", True) for_user = kwargs.get("for_user", "") - for_user = utils.irc.lower(self.server, for_user + for_user = utils.irc.lower(self.server.case_mapping, for_user ) if for_user else None not_pattern = kwargs.get("not_pattern", None) for line in self.lines: @@ -48,8 +53,8 @@ class Buffer(object): elif re.search(pattern, line.message): if not_pattern and re.search(not_pattern, line.message): continue - if for_user and not utils.irc.lower(self.server, line.sender - ) == for_user: + if for_user and not utils.irc.lower(self.server.case_mapping, + line.sender) == for_user: continue return line def skip_next(self): diff --git a/src/IRCChannel.py b/src/IRCChannel.py index ffe34fa9..7d2c38fa 100644 --- a/src/IRCChannel.py +++ b/src/IRCChannel.py @@ -1,9 +1,10 @@ -import uuid -from src import IRCBuffer, IRCObject, utils +import typing, uuid +from src import IRCBot, IRCBuffer, IRCObject, IRCServer, IRCUser, utils class Channel(IRCObject.Object): - def __init__(self, name, id, server, bot): - self.name = utils.irc.lower(server, name) + def __init__(self, name: str, id, server: "IRCServer.Server", + bot: "IRCBot.Bot"): + self.name = utils.irc.lower(server.case_mapping, name) self.id = id self.server = server self.bot = bot @@ -18,23 +19,24 @@ class Channel(IRCObject.Object): self.created_timestamp = None self.buffer = IRCBuffer.Buffer(bot, server) - def __repr__(self): + def __repr__(self) -> str: return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name) - def __str__(self): + def __str__(self) -> str: return self.name - def set_topic(self, topic): + def set_topic(self, topic: str): self.topic = topic - def set_topic_setter(self, nickname, username=None, hostname=None): + def set_topic_setter(self, nickname: str, username: str=None, + hostname: str=None): self.topic_setter_nickname = nickname self.topic_setter_username = username self.topic_setter_hostname = hostname - def set_topic_time(self, unix_timestamp): + def set_topic_time(self, unix_timestamp: int): self.topic_time = unix_timestamp - def add_user(self, user): + def add_user(self, user: IRCUser.User): self.users.add(user) - def remove_user(self, user): + def remove_user(self, user: IRCUser.User): self.users.remove(user) for mode in list(self.modes.keys()): if mode in self.server.prefix_modes and user in self.modes[mode]: @@ -43,10 +45,10 @@ class Channel(IRCObject.Object): del self.modes[mode] if user in self.user_modes: del self.user_modes[user] - def has_user(self, user): + def has_user(self, user: IRCUser.User) -> bool: return user in self.users - def add_mode(self, mode, arg=None): + def add_mode(self, mode: str, arg: str=None): if not mode in self.modes: self.modes[mode] = set([]) if arg: @@ -59,7 +61,7 @@ class Channel(IRCObject.Object): self.user_modes[user].add(mode) else: self.modes[mode].add(arg.lower()) - def remove_mode(self, mode, arg=None): + def remove_mode(self, mode: str, arg: str=None): if not arg: del self.modes[mode] else: @@ -76,63 +78,70 @@ class Channel(IRCObject.Object): self.modes[mode].discard(arg.lower()) if not len(self.modes[mode]): del self.modes[mode] - def change_mode(self, remove, mode, arg=None): + def change_mode(self, remove: bool, mode: str, arg: str=None): if remove: self.remove_mode(mode, arg) else: self.add_mode(mode, arg) - def set_setting(self, setting, value): + def set_setting(self, setting: str, value: typing.Any): self.bot.database.channel_settings.set(self.id, setting, value) - def get_setting(self, setting, default=None): + def get_setting(self, setting: str, default: typing.Any=None + ) -> typing.Any: return self.bot.database.channel_settings.get(self.id, setting, default) - def find_settings(self, pattern, default=[]): + def find_settings(self, pattern: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.bot.database.channel_settings.find(self.id, pattern, default) - def find_settings_prefix(self, prefix, default=[]): + def find_settings_prefix(self, prefix: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.bot.database.channel_settings.find_prefix(self.id, prefix, default) - def del_setting(self, setting): + def del_setting(self, setting: str): self.bot.database.channel_settings.delete(self.id, setting) - def set_user_setting(self, user_id, setting, value): + def set_user_setting(self, user_id: int, setting: str, value: typing.Any): self.bot.database.user_channel_settings.set(user_id, self.id, setting, value) - def get_user_setting(self, user_id, setting, default=None): + def get_user_setting(self, user_id: int, setting: str, + default: typing.Any=None) -> typing.Any: return self.bot.database.user_channel_settings.get(user_id, self.id, setting, default) - def find_user_settings(self, user_i, pattern, default=[]): + def find_user_settings(self, user_id: int, pattern: str, + default: typing.Any=[]) -> typing.List[typing.Any]: return self.bot.database.user_channel_settings.find(user_id, self.id, pattern, default) - def find_user_settings_prefix(self, user_id, prefix, default=[]): + def find_user_settings_prefix(self, user_id: int, prefix: str, + default: typing.Any=[]) -> typing.List[typing.Any]: return self.bot.database.user_channel_settings.find_prefix( user_id, self.id, prefix, default) - def del_user_setting(self, user_id, setting): + def del_user_setting(self, user_id: int, setting: str): self.bot.database.user_channel_settings.delete(user_id, self.id, setting) - def find_all_by_setting(self, setting, default=[]): + def find_all_by_setting(self, setting: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.bot.database.user_channel_settings.find_all_by_setting( self.id, setting, default) - def send_message(self, text, prefix=None, tags={}): + def send_message(self, text: str, prefix: str=None, tags: dict={}): self.server.send_message(self.name, text, prefix=prefix, tags=tags) - def send_notice(self, text, prefix=None, tags={}): + def send_notice(self, text: str, prefix: str=None, tags: dict={}): self.server.send_notice(self.name, text, prefix=prefix, tags=tags) - def send_mode(self, mode=None, target=None): + def send_mode(self, mode: str=None, target: str=None): self.server.send_mode(self.name, mode, target) - def send_kick(self, target, reason=None): + def send_kick(self, target: str, reason: str=None): self.server.send_kick(self.name, target, reason) - def send_ban(self, hostmask): + def send_ban(self, hostmask: str): self.server.send_mode(self.name, "+b", hostmask) - def send_unban(self, hostmask): + def send_unban(self, hostmask: str): self.server.send_mode(self.name, "-b", hostmask) - def send_topic(self, topic): + def send_topic(self, topic: str): self.server.send_topic(self.name, topic) - def send_part(self, reason=None): + def send_part(self, reason: str=None): self.server.send_part(self.name, reason) - def mode_or_above(self, user, mode): + def mode_or_above(self, user: IRCUser.User, mode: str) -> bool: mode_orders = list(self.server.prefix_modes) mode_index = mode_orders.index(mode) for mode in mode_orders[:mode_index+1]: @@ -140,8 +149,8 @@ class Channel(IRCObject.Object): return True return False - def has_mode(self, user, mode): + def has_mode(self, user: IRCUser.User, mode: str) -> bool: return user in self.modes.get(mode, []) - def get_user_status(self, user): + def get_user_status(self, user: IRCUser.User) -> typing.Set: return self.user_modes.get(user, []) diff --git a/src/IRCServer.py b/src/IRCServer.py index e92cd291..16ad839a 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -1,5 +1,5 @@ -import collections, socket, ssl, sys, time -from src import IRCChannel, IRCObject, IRCUser, utils +import collections, socket, ssl, sys, time, typing +from src import EventManager, IRCBot, IRCChannel, IRCObject, IRCUser, utils THROTTLE_LINES = 4 THROTTLE_SECONDS = 1 @@ -7,8 +7,12 @@ READ_TIMEOUT_SECONDS = 120 PING_INTERVAL_SECONDS = 30 class Server(IRCObject.Object): - def __init__(self, bot, events, id, alias, hostname, port, password, - ipv4, tls, bindhost, nickname, username, realname): + def __init__(self, + bot: "IRCBot.Bot", + events: EventManager.EventHook, + id: int, alias: str, hostname: str, port: int, password: str, + ipv4: bool, tls: bool, bindhost: str, + nickname: str, username: str, realname: str): self.connected = False self.bot = bot self.events = events @@ -121,77 +125,80 @@ class Server(IRCObject.Object): except: pass - def set_setting(self, setting, value): + def set_setting(self, setting: str, value: typing.Any): self.bot.database.server_settings.set(self.id, setting, value) - def get_setting(self, setting, default=None): + def get_setting(self, setting: str, default: typing.Any=None): return self.bot.database.server_settings.get(self.id, setting, default) - def find_settings(self, pattern, default=[]): + def find_settings(self, pattern: str, default: typing.Any=[]): return self.bot.database.server_settings.find(self.id, pattern, default) - def find_settings_prefix(self, prefix, default=[]): + def find_settings_prefix(self, prefix: str, default: typing.Any=[]): return self.bot.database.server_settings.find_prefix( self.id, prefix, default) - def del_setting(self, setting): + def del_setting(self, setting: str): self.bot.database.server_settings.delete(self.id, setting) - def get_user_setting(self, nickname, setting, default=None): + def get_user_setting(self, nickname: str, setting: str, + default: typing.Any=None): user_id = self.get_user_id(nickname) return self.bot.database.user_settings.get(user_id, setting, default) - def set_user_setting(self, nickname, setting, value): + def set_user_setting(self, nickname: str, setting: str, value: typing.Any): user_id = self.get_user_id(nickname) self.bot.database.user_settings.set(user_id, setting, value) - def get_all_user_settings(self, setting, default=[]): + def get_all_user_settings(self, setting: str, default: typing.Any=[]): return self.bot.database.user_settings.find_all_by_setting( self.id, setting, default) - def find_all_user_channel_settings(self, setting, default=[]): + def find_all_user_channel_settings(self, setting: str, + default: typing.Any=[]): return self.bot.database.user_channel_settings.find_all_by_setting( self.id, setting, default) - def set_own_nickname(self, nickname): + def set_own_nickname(self, nickname: str): self.nickname = nickname - self.nickname_lower = utils.irc.lower(self, nickname) - def is_own_nickname(self, nickname): + self.nickname_lower = utils.irc.lower(self.case_mapping, nickname) + def is_own_nickname(self, nickname: str): return utils.irc.equals(self, nickname, self.nickname) - def add_own_mode(self, mode, arg=None): + def add_own_mode(self, mode: str, arg: str=None): self.own_modes[mode] = arg - def remove_own_mode(self, mode): + def remove_own_mode(self, mode: str): del self.own_modes[mode] - def change_own_mode(self, remove, mode, arg=None): + def change_own_mode(self, remove: bool, mode: str, arg: str=None): if remove: self.remove_own_mode(mode) else: self.add_own_mode(mode, arg) - def has_user(self, nickname): - return utils.irc.lower(self, nickname) in self.users - def get_user(self, nickname, create=True): + def has_user(self, nickname: str): + return utils.irc.lower(self.case_mapping, nickname) in self.users + def get_user(self, nickname: str, create: bool=True): if not self.has_user(nickname) and create: user_id = self.get_user_id(nickname) new_user = IRCUser.User(nickname, user_id, self, self.bot) self.events.on("new.user").call(user=new_user, server=self) self.users[new_user.nickname_lower] = new_user self.new_users.add(new_user) - return self.users.get(utils.irc.lower(self, nickname), None) - def get_user_id(self, nickname): + return self.users.get(utils.irc.lower(self.case_mapping, nickname), + None) + def get_user_id(self, nickname: str): self.bot.database.users.add(self.id, nickname) return self.bot.database.users.get_id(self.id, nickname) - def remove_user(self, user): + def remove_user(self, user: IRCUser.User): del self.users[user.nickname_lower] for channel in user.channels: channel.remove_user(user) - def change_user_nickname(self, old_nickname, new_nickname): - user = self.users.pop(utils.irc.lower(self, old_nickname)) + def change_user_nickname(self, old_nickname: str, new_nickname: str): + user = self.users.pop(utils.irc.lower(self.case_mapping, old_nickname)) user._id = self.get_user_id(new_nickname) - self.users[utils.irc.lower(self, new_nickname)] = user - def has_channel(self, channel_name): + self.users[utils.irc.lower(self.case_mapping, new_nickname)] = user + def has_channel(self, channel_name: str): return channel_name[0] in self.channel_types and utils.irc.lower( - self, channel_name) in self.channels - def get_channel(self, channel_name): + self.case_mapping, channel_name) in self.channels + def get_channel(self, channel_name: str): if not self.has_channel(channel_name): channel_id = self.get_channel_id(channel_name) new_channel = IRCChannel.Channel(channel_name, channel_id, @@ -199,15 +206,15 @@ class Server(IRCObject.Object): self.events.on("new.channel").call(channel=new_channel, server=self) self.channels[new_channel.name] = new_channel - return self.channels[utils.irc.lower(self, channel_name)] - def get_channel_id(self, channel_name): + return self.channels[utils.irc.lower(self.case_mapping, channel_name)] + def get_channel_id(self, channel_name: str): self.bot.database.channels.add(self.id, channel_name) return self.bot.database.channels.get_id(self.id, channel_name) - def remove_channel(self, channel): + def remove_channel(self, channel: IRCChannel.Channel): for user in channel.users: user.part_channel(channel) del self.channels[channel.name] - def parse_data(self, line): + def parse_data(self, line: str): if not line: return self.events.on("raw").call_unsafe(server=self, line=line) @@ -271,7 +278,7 @@ class Server(IRCObject.Object): def read_timed_out(self): return self.until_read_timeout == 0 - def send(self, data): + def send(self, data: str): returned = self.events.on("preprocess.send").call_unsafe_for_result( server=self, line=data) line = returned or data @@ -314,16 +321,16 @@ class Server(IRCObject.Object): time_left = time_left-now return time_left - def send_user(self, username, realname): + def send_user(self, username: str, realname: str): self.send("USER %s 0 * :%s" % (username, realname)) - def send_nick(self, nickname): + def send_nick(self, nickname: str): self.send("NICK %s" % nickname) def send_capibility_ls(self): self.send("CAP LS 302") - def queue_capability(self, capability): + def queue_capability(self, capability: str): self._capability_queue.add(capability) - def queue_capabilities(self, capabilities): + def queue_capabilities(self, capabilities: typing.List[str]): self._capability_queue.update(capabilities) def send_capability_queue(self): if self.has_capability_queue(): @@ -332,46 +339,46 @@ class Server(IRCObject.Object): self.send_capability_request(capabilities) def has_capability_queue(self): return bool(len(self._capability_queue)) - def send_capability_request(self, capability): + def send_capability_request(self, capability: str): self.send("CAP REQ :%s" % capability) def send_capability_end(self): self.send("CAP END") - def send_authenticate(self, text): + def send_authenticate(self, text: str): self.send("AUTHENTICATE %s" % text) def send_starttls(self): self.send("STARTTLS") def waiting_for_capabilities(self): return bool(len(self._capabilities_waiting)) - def wait_for_capability(self, capability): + def wait_for_capability(self, capability: str): self._capabilities_waiting.add(capability) - def capability_done(self, capability): + def capability_done(self, capability: str): self._capabilities_waiting.remove(capability) if not self._capabilities_waiting: self.send_capability_end() - def send_pass(self, password): + def send_pass(self, password: str): self.send("PASS %s" % password) - def send_ping(self, nonce="hello"): + def send_ping(self, nonce: str="hello"): self.send("PING :%s" % nonce) - def send_pong(self, nonce="hello"): + def send_pong(self, nonce: str="hello"): self.send("PONG :%s" % nonce) - def try_rejoin(self, event): + def try_rejoin(self, event: EventManager.Event): if event["server_id"] == self.id and event["channel_name" ] in self.attempted_join: self.send_join(event["channel_name"], event["key"]) - def send_join(self, channel_name, key=None): + def send_join(self, channel_name: str, key: str=None): self.send("JOIN %s%s" % (channel_name, "" if key == None else " %s" % key)) - def send_part(self, channel_name, reason=None): + def send_part(self, channel_name: str, reason: str=None): self.send("PART %s%s" % (channel_name, "" if reason == None else " %s" % reason)) - def send_quit(self, reason="Leaving"): + def send_quit(self, reason: str="Leaving"): self.send("QUIT :%s" % reason) - def _tag_str(self, tags): + def _tag_str(self, tags: dict): tag_str = "" for tag, value in tags.items(): if tag_str: @@ -383,7 +390,8 @@ class Server(IRCObject.Object): tag_str = "@%s " % tag_str return tag_str - def send_message(self, target, message, prefix=None, tags={}): + def send_message(self, target: str, message: str, prefix: str=None, + tags: dict={}): full_message = message if not prefix else prefix+message self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target, full_message)) @@ -408,7 +416,8 @@ class Server(IRCObject.Object): message=full_message, message_split=full_message_split, user=user, action=action, server=self) - def send_notice(self, target, message, prefix=None, tags={}): + def send_notice(self, target: str, message: str, prefix: str=None, + tags: dict={}): full_message = message if not prefix else prefix+message self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target, full_message)) @@ -419,31 +428,31 @@ class Server(IRCObject.Object): self.get_user(target).buffer.add_notice(None, message, tags, True) - def send_mode(self, target, mode=None, args=None): + def send_mode(self, target: str, mode: str=None, args: str=None): self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode, "" if args == None else " %s" % args)) - def send_topic(self, channel_name, topic): + def send_topic(self, channel_name: str, topic: str): self.send("TOPIC %s :%s" % (channel_name, topic)) - def send_kick(self, channel_name, target, reason=None): + def send_kick(self, channel_name: str, target: str, reason: str=None): self.send("KICK %s %s%s" % (channel_name, target, "" if reason == None else " :%s" % reason)) - def send_names(self, channel_name): + def send_names(self, channel_name: str): self.send("NAMES %s" % channel_name) - def send_list(self, search_for=None): + def send_list(self, search_for: str=None): self.send( "LIST%s" % "" if search_for == None else " %s" % search_for) - def send_invite(self, target, channel_name): + def send_invite(self, target: str, channel_name: str): self.send("INVITE %s %s" % (target, channel_name)) - def send_whois(self, target): + def send_whois(self, target: str): self.send("WHOIS %s" % target) - def send_whowas(self, target, amount=None, server=None): + def send_whowas(self, target: str, amount: int=None, server: str=None): self.send("WHOWAS %s%s%s" % (target, "" if amount == None else " %s" % amount, "" if server == None else " :%s" % server)) - def send_who(self, filter=None): + def send_who(self, filter: str=None): self.send("WHO%s" % ("" if filter == None else " %s" % filter)) - def send_whox(self, mask, filter, fields, label=None): + def send_whox(self, mask: str, filter: str, fields: str, label: str=None): self.send("WHO %s %s%%%s%s" % (mask, filter, fields, ","+label if label else "")) diff --git a/src/IRCUser.py b/src/IRCUser.py index c0bbb862..edead0d2 100644 --- a/src/IRCUser.py +++ b/src/IRCUser.py @@ -1,8 +1,9 @@ -import uuid -from src import IRCBuffer, IRCObject, utils +import typing, uuid +from src import IRCBot, IRCChannel, IRCBuffer, IRCObject, IRCServer, utils class User(IRCObject.Object): - def __init__(self, nickname, id, server, bot): + def __init__(self, nickname: str, id: int, server: "IRCServer.Server", + bot: "IRCBot.Bot"): self.server = server self.set_nickname(nickname) self._id = id @@ -20,46 +21,51 @@ class User(IRCObject.Object): self.away = False self.buffer = IRCBuffer.Buffer(bot, server) - def __repr__(self): + def __repr__(self) -> str: return "IRCUser.User(%s|%s)" % (self.server.name, self.name) - def __str__(self): + def __str__(self) -> str: return self.nickname - def get_id(self): + def get_id(self)-> int: return (self.identified_account_id_override or self.identified_account_id or self._id) - def get_identified_account(self): + def get_identified_account(self) -> str: return (self.identified_account_override or self.identified_account) - def set_nickname(self, nickname): + def set_nickname(self, nickname: str): self.nickname = nickname - self.nickname_lower = utils.irc.lower(self.server, nickname) + self.nickname_lower = utils.irc.lower(self.server.case_mapping, + nickname) self.name = self.nickname_lower - def join_channel(self, channel): + def join_channel(self, channel: "IRCChannel.Channel"): self.channels.add(channel) - def part_channel(self, channel): + def part_channel(self, channel: "IRCChannel.Channel"): self.channels.remove(channel) - def set_setting(self, setting, value): + + def set_setting(self, setting: str, value: typing.Any): self.bot.database.user_settings.set(self.get_id(), setting, value) - def get_setting(self, setting, default=None): + def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any: return self.bot.database.user_settings.get(self.get_id(), setting, default) - def find_settings(self, pattern, default=[]): + def find_settings(self, pattern: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.bot.database.user_settings.find(self.get_id(), pattern, default) - def find_settings_prefix(self, prefix, default=[]): + def find_settings_prefix(self, prefix: str, default: typing.Any=[] + ) -> typing.List[typing.Any]: return self.bot.database.user_settings.find_prefix(self.get_id(), prefix, default) def del_setting(self, setting): self.bot.database.user_settings.delete(self.get_id(), setting) - def get_channel_settings_per_setting(self, setting, default=[]): + def get_channel_settings_per_setting(self, setting: str, + default: typing.Any=[]) -> typing.List[typing.Any]: return self.bot.database.user_channel_settings.find_by_setting( self.get_id(), setting, default) - def send_message(self, message, prefix=None, tags={}): + def send_message(self, message: str, prefix: str=None, tags: dict={}): self.server.send_message(self.nickname, message, prefix=prefix, tags=tags) - def send_notice(self, text, prefix=None, tags={}): + def send_notice(self, text: str, prefix: str=None, tags: dict={}): self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags) - def send_ctcp_response(self, command, args): + def send_ctcp_response(self, command: str, args: str): self.send_notice("\x01%s %s\x01" % (command, args)) diff --git a/src/Logging.py b/src/Logging.py index 6ea6efe8..d5c42e56 100644 --- a/src/Logging.py +++ b/src/Logging.py @@ -1,4 +1,4 @@ -import logging, logging.handlers, os, sys, time +import logging, logging.handlers, os, sys, time, typing LEVELS = { "trace": logging.DEBUG-1, @@ -23,7 +23,7 @@ class BitBotFormatter(logging.Formatter): return s class Log(object): - def __init__(self, level, location): + def __init__(self, level: str, location: str): logging.addLevelName(LEVELS["trace"], "TRACE") self.logger = logging.getLogger(__name__) @@ -49,17 +49,17 @@ class Log(object): file_handler.setFormatter(formatter) self.logger.addHandler(file_handler) - def trace(self, message, params, **kwargs): + def trace(self, message: str, params: typing.List, **kwargs): self._log(message, params, LEVELS["trace"], kwargs) - def debug(self, message, params, **kwargs): + def debug(self, message: str, params: typing.List, **kwargs): self._log(message, params, logging.DEBUG, kwargs) - def info(self, message, params, **kwargs): + def info(self, message: str, params: typing.List, **kwargs): self._log(message, params, logging.INFO, kwargs) - def warn(self, message, params, **kwargs): + def warn(self, message: str, params: typing.List, **kwargs): self._log(message, params, logging.WARN, kwargs) - def error(self, message, params, **kwargs): + def error(self, message: str, params: typing.List, **kwargs): self._log(message, params, logging.ERROR, kwargs) - def critical(self, message, params, **kwargs): + def critical(self, message: str, params: typing.List, **kwargs): self._log(message, params, logging.CRITICAL, kwargs) - def _log(self, message, params, level, kwargs): + def _log(self, message: str, params: typing.List, level: int, kwargs: dict): self.logger.log(level, message, *params, **kwargs) diff --git a/src/ModuleManager.py b/src/ModuleManager.py index 8f182c8d..f8ddd077 100644 --- a/src/ModuleManager.py +++ b/src/ModuleManager.py @@ -1,8 +1,5 @@ -import gc, glob, imp, io, inspect, os, sys, uuid -from . import utils - -BITBOT_HOOKS_MAGIC = "__bitbot_hooks" -BITBOT_EXPORTS_MAGIC = "__bitbot_exports" +import gc, glob, imp, io, inspect, os, sys, typing, uuid +from src import Config, EventManager, Exports, IRCBot, Logging, Timers, utils class ModuleException(Exception): pass @@ -22,7 +19,11 @@ class ModuleNotLoadedWarning(ModuleWarning): pass class BaseModule(object): - def __init__(self, bot, events, exports, timers): + def __init__(self, + bot: "IRCBot.Bot", + events: EventManager.EventHook, + exports: Exports.Exports, + timers: Timers.Timers): self.bot = bot self.events = events self.exports = exports @@ -32,7 +33,13 @@ class BaseModule(object): pass class ModuleManager(object): - def __init__(self, events, exports, timers, config, log, directory): + def __init__(self, + events: EventManager.EventHook, + exports: Exports.Exports, + timers: Timers.Timers, + config: Config.Config, + log: Logging.Log, + directory: str): self.events = events self.exports = exports self.config = config @@ -43,23 +50,24 @@ class ModuleManager(object): self.modules = {} self.waiting_requirement = {} - def list_modules(self): + def list_modules(self) -> typing.List[str]: return sorted(glob.glob(os.path.join(self.directory, "*.py"))) - def _module_name(self, path): + def _module_name(self, path: str) -> str: return os.path.basename(path).rsplit(".py", 1)[0].lower() - def _module_path(self, name): + def _module_path(self, name: str) -> str: return os.path.join(self.directory, "%s.py" % name) - def _import_name(self, name): + def _import_name(self, name: str) -> str: return "bitbot_%s" % name - def _get_magic(self, obj, magic, default): + def _get_magic(self, obj: typing.Any, magic: str, default: typing.Any + ) -> typing.Any: return getattr(obj, magic) if hasattr(obj, magic) else default - def _load_module(self, bot, name): + def _load_module(self, bot: "IRCBot.Bot", name: str): path = self._module_path(name) - for hashflag, value in utils.get_hashflags(path): + for hashflag, value in utils.parse.hashflags(path): if hashflag == "ignore": # nope, ignore this module. raise ModuleNotLoadedWarning("module ignored") @@ -97,10 +105,12 @@ class ModuleManager(object): module_object._name = name.title() for attribute_name in dir(module_object): attribute = getattr(module_object, attribute_name) - for hook in self._get_magic(attribute, BITBOT_HOOKS_MAGIC, []): + for hook in self._get_magic(attribute, + utils.consts.BITBOT_HOOKS_MAGIC, []): context_events.on(hook["event"]).hook(attribute, **hook["kwargs"]) - for export in self._get_magic(module_object, BITBOT_EXPORTS_MAGIC, []): + for export in self._get_magic(module_object, + utils.consts.BITBOT_EXPORTS_MAGIC, []): context_exports.add(export["setting"], export["value"]) module_object._context = context @@ -111,7 +121,7 @@ class ModuleManager(object): "attempted to be used twice") return module_object - def load_module(self, bot, name): + def load_module(self, bot: "IRCBot.Bot", name: str): try: module = self._load_module(bot, name) except ModuleWarning as warning: @@ -128,7 +138,8 @@ class ModuleManager(object): self.load_module(bot, requirement_name) self.log.info("Module '%s' loaded", [name]) - def load_modules(self, bot, whitelist=[], blacklist=[]): + def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[], + blacklist: typing.List[str]=[]): for path in self.list_modules(): name = self._module_name(path) if name in whitelist or (not whitelist and not name in blacklist): @@ -137,7 +148,7 @@ class ModuleManager(object): except ModuleWarning: pass - def unload_module(self, name): + def unload_module(self, name: str): if not name in self.modules: raise ModuleNotFoundException() module = self.modules[name] diff --git a/src/Socket.py b/src/Socket.py index f405b0a9..474336d0 100644 --- a/src/Socket.py +++ b/src/Socket.py @@ -1,7 +1,9 @@ - +import socket, typing class Socket(object): - def __init__(self, socket, on_read, encoding="utf8"): + def __init__(self, socket: socket.socket, + on_read: typing.Callable[["Socket", str], None], + encoding: str="utf8"): self.socket = socket self._on_read = on_read self.encoding = encoding @@ -12,18 +14,18 @@ class Socket(object): self.length = None self.connected = True - def fileno(self): + def fileno(self) -> int: return self.socket.fileno() def disconnect(self): self.connected = False - def _decode(self, s): - return s.decode(self.encoding) if self.encoding else s - def _encode(self, s): - return s.encode(self.encoding) if self.encoding else s + def _decode(self, s: bytes) -> str: + return s.decode(self.encoding) + def _encode(self, s: str) -> bytes: + return s.encode(self.encoding) - def read(self): + def read(self) -> typing.Optional[typing.List[str]]: data = self.socket.recv(1024) if not data: return None @@ -35,17 +37,17 @@ class Socket(object): if data_split[-1]: self._read_buffer = data_split.pop(-1) return [self._decode(data) for data in data_split] - return [data.decode(self.encoding)] + return [self._decode(data)] - def parse_data(self, data): + def parse_data(self, data: str): self._on_read(self, data) - def send(self, data): + def send(self, data: str): self._write_buffer += self._encode(data) def _send(self): self._write_buffer = self._write_buffer[self.socket.send( self._write_buffer):] - def waiting_send(self): + def waiting_send(self) -> bool: return bool(len(self._write_buffer)) diff --git a/src/Timers.py b/src/Timers.py index e49e7edc..c3336d87 100644 --- a/src/Timers.py +++ b/src/Timers.py @@ -1,7 +1,9 @@ -import time, uuid +import time, typing, uuid +from src import Database, EventManager, Logging class Timer(object): - def __init__(self, id, context, name, delay, next_due, kwargs): + def __init__(self, id: int, context: str, name: str, delay: float, + next_due: float, kwargs: dict): self.id = id self.context = context self.name = name @@ -15,9 +17,9 @@ class Timer(object): def set_next_due(self): self.next_due = time.time()+self.delay - def due(self): + def due(self) -> bool: return self.time_left() <= 0 - def time_left(self): + def time_left(self) -> float: return self.next_due-time.time() def redo(self): @@ -25,42 +27,33 @@ class Timer(object): self.set_next_due() def finish(self): self._done = True - def done(self): + def done(self) -> bool: return self._done -class TimersContext(object): - def __init__(self, parent, context): - self._parent = parent - self.context = context - def add(self, name, delay, next_due=None, **kwargs): - self._parent._add(self.context, name, delay, next_due, None, False, - kwargs) - def add_persistent(self, name, delay, next_due=None, **kwargs): - self._parent._add(None, name, delay, next_due, None, True, - kwargs) - class Timers(object): - def __init__(self, database, events, log): + def __init__(self, database: Database.Database, + events: EventManager.EventHook, + log: Logging.Log): self.database = database self.events = events self.log = log self.timers = [] self.context_timers = {} - def new_context(self, context): + def new_context(self, context: str) -> "TimersContext": return TimersContext(self, context) - def setup(self, timers): + def setup(self, timers: typing.List[typing.Tuple[str, dict]]): for name, timer in timers: id = name.split("timer-", 1)[1] self._add(timer["name"], None, timer["delay"], timer[ "next-due"], id, False, timer["kwargs"]) - def _persist(self, timer): + def _persist(self, timer: Timer): self.database.bot_settings.set("timer-%s" % timer.id, { "name": timer.name, "delay": timer.delay, "next-due": timer.next_due, "kwargs": timer.kwargs}) - def _remove(self, timer): + def _remove(self, timer: Timer): if timer.context: self.context_timers[timer.context].remove(timer) if not self.context_timers[timer.context]: @@ -69,11 +62,13 @@ class Timers(object): self.timers.remove(timer) self.database.bot_settings.delete("timer-%s" % timer.id) - def add(self, name, delay, next_due=None, **kwargs): + def add(self, name: str, delay: float, next_due: float=None, **kwargs): self._add(None, name, delay, next_due, None, False, kwargs) - def add_persistent(self, name, delay, next_due=None, **kwargs): + def add_persistent(self, name: str, delay: float, next_due: float=None, + **kwargs): self._add(None, name, delay, next_due, None, True, kwargs) - def _add(self, context, name, delay, next_due, id, persist, kwargs): + def _add(self, context: str, name: str, delay: float, next_due: float, + id: str, persist: bool, kwargs: dict): id = id or uuid.uuid4().hex timer = Timer(id, context, name, delay, next_due, kwargs) if persist: @@ -86,13 +81,13 @@ class Timers(object): else: self.timers.append(timer) - def next(self): + def next(self) -> float: times = filter(None, [timer.time_left() for timer in self.get_timers()]) if not times: return None return max(min(times), 0) - def get_timers(self): + def get_timers(self) -> typing.List[Timer]: return self.timers + sum(self.context_timers.values(), []) def call(self): @@ -104,6 +99,19 @@ class Timers(object): if timer.done(): self._remove(timer) - def purge_context(self, context): + def purge_context(self, context: str): if context in self.context_timers: del self.context_timers[context] + +class TimersContext(object): + def __init__(self, parent: Timers, context: str): + self._parent = parent + self.context = context + def add(self, name: str, delay: float, next_due: float=None, + **kwargs): + self._parent._add(self.context, name, delay, next_due, None, False, + kwargs) + def add_persistent(self, name: str, delay: float, next_due: float=None, + **kwargs): + self._parent._add(None, name, delay, next_due, None, True, + kwargs) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 87671108..52de64ba 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,5 @@ -import decimal, io, re -from src import ModuleManager -from . import irc, http +import decimal, io, re, typing +from src.utils import consts, irc, http, parse TIME_SECOND = 1 TIME_MINUTE = TIME_SECOND*60 @@ -8,7 +7,7 @@ TIME_HOUR = TIME_MINUTE*60 TIME_DAY = TIME_HOUR*24 TIME_WEEK = TIME_DAY*7 -def time_unit(seconds): +def time_unit(seconds: int) -> typing.Tuple[int, str]: since = None unit = None if seconds >= TIME_WEEK: @@ -29,7 +28,7 @@ def time_unit(seconds): since = int(since) if since > 1: unit = "%ss" % unit # pluralise the unit - return [since, unit] + return (since, unit) REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I) @@ -38,7 +37,7 @@ SECONDS_HOURS = SECONDS_MINUTES*60 SECONDS_DAYS = SECONDS_HOURS*24 SECONDS_WEEKS = SECONDS_DAYS*7 -def from_pretty_time(pretty_time): +def from_pretty_time(pretty_time: str) -> typing.Optional[int]: seconds = 0 for match in re.findall(REGEX_PRETTYTIME, pretty_time): number, unit = int(match[:-1]), match[-1].lower() @@ -54,12 +53,14 @@ def from_pretty_time(pretty_time): if seconds > 0: return seconds +UNIT_MINIMUM = 6 UNIT_SECOND = 5 UNIT_MINUTE = 4 UNIT_HOUR = 3 UNIT_DAY = 2 UNIT_WEEK = 1 -def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6): +def to_pretty_time(total_seconds: int, minimum_unit: int=UNIT_SECOND, + max_units: int=UNIT_MINIMUM) -> str: minutes, seconds = divmod(total_seconds, 60) hours, minutes = divmod(minutes, 60) days, hours = divmod(hours, 24) @@ -84,7 +85,7 @@ def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6): units += 1 return out -def parse_number(s): +def parse_number(s: str) -> str: try: decimal.Decimal(s) return s @@ -110,28 +111,18 @@ def parse_number(s): IS_TRUE = ["true", "yes", "on", "y"] IS_FALSE = ["false", "no", "off", "n"] -def bool_or_none(s): +def bool_or_none(s: str) -> typing.Optional[bool]: s = s.lower() if s in IS_TRUE: return True elif s in IS_FALSE: return False -def int_or_none(s): +def int_or_none(s: str) -> typing.Optional[int]: stripped_s = s.lstrip("0") if stripped_s.isdigit(): return int(stripped_s) -def get_closest_setting(event, setting, default=None): - server = event["server"] - if "channel" in event: - closest = event["channel"] - elif "target" in event and "is_channel" in event and event["is_channel"]: - closest = event["target"] - else: - closest = event["user"] - return closest.get_setting(setting, server.get_setting(setting, default)) - -def prevent_highlight(nickname): +def prevent_highlight(nickname: str) -> str: return nickname[0]+"\u200c"+nickname[1:] class EventError(Exception): @@ -139,86 +130,40 @@ class EventError(Exception): class EventsResultsError(EventError): def __init__(self): EventError.__init__(self, "Failed to load results") +class EventsNotEnoughArgsError(EventError): + def __init__(self, n): + EventError.__init__(self, "Not enough arguments (minimum %d)" % n) +class EventsUsageError(EventError): + def __init__(self, usage): + EventError.__init__(self, "Not enough arguments, usage: %s" % usage) -def _set_get_append(obj, setting, item): +def _set_get_append(obj: typing.Any, setting: str, item: typing.Any): if not hasattr(obj, setting): setattr(obj, setting, []) getattr(obj, setting).append(item) -def hook(event, **kwargs): +def hook(event: str, **kwargs): def _hook_func(func): - _set_get_append(func, ModuleManager.BITBOT_HOOKS_MAGIC, + _set_get_append(func, consts.BITBOT_HOOKS_MAGIC, {"event": event, "kwargs": kwargs}) return func return _hook_func -def export(setting, value): +def export(setting: str, value: typing.Any): def _export_func(module): - _set_get_append(module, ModuleManager.BITBOT_EXPORTS_MAGIC, + _set_get_append(module, consts.BITBOT_EXPORTS_MAGIC, {"setting": setting, "value": value}) return module return _export_func -COMMENT_TYPES = ["#", "//"] -def get_hashflags(filename): - hashflags = {} - with io.open(filename, mode="r", encoding="utf8") as f: - for line in f: - line = line.strip("\n") - found = False - for comment_type in COMMENT_TYPES: - if line.startswith(comment_type): - line = line.replace(comment_type, "", 1).lstrip() - found = True - break +TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any] +def top_10(items: typing.List[typing.Any], + convert_key: TOP_10_CALLABLE=lambda x: x, + value_format: TOP_10_CALLABLE=lambda x: x): + top_10 = sorted(items.keys()) + top_10 = sorted(top_10, key=items.get, reverse=True)[:10] - if not found: - break - elif line.startswith("--"): - hashflag, sep, value = line[2:].partition(" ") - hashflags[hashflag] = value if sep else None - return hashflags.items() + top_10_items = [] + for key in top_10: + top_10_items.append("%s (%s)" % (convert_key(key), + value_format(items[key]))) -class Docstring(object): - def __init__(self, description, items, var_items): - self.description = description - self.items = items - self.var_items = var_items - -def parse_docstring(s): - description = "" - last_item = None - items = {} - var_items = {} - if s: - for line in s.split("\n"): - line = line.strip() - - if line: - if line[0] == ":": - key, _, value = line[1:].partition(": ") - last_item = key - - if key in var_items: - var_items[key].append(value) - elif key in items: - var_items[key] = [items.pop(key), value] - else: - items[key] = value - else: - if last_item: - items[last_item] += " %s" % line - else: - if description: - description += " " - description += line - return Docstring(description, items, var_items) - -def top_10(items, convert_key=lambda x: x, value_format=lambda x: x): - top_10 = sorted(items.keys()) - top_10 = sorted(top_10, key=items.get, reverse=True)[:10] - - top_10_items = [] - for key in top_10: - top_10_items.append("%s (%s)" % (convert_key(key), - value_format(items[key]))) - - return top_10_items + return top_10_items diff --git a/src/utils/consts.py b/src/utils/consts.py new file mode 100644 index 00000000..d2816509 --- /dev/null +++ b/src/utils/consts.py @@ -0,0 +1,2 @@ +BITBOT_HOOKS_MAGIC = "__bitbot_hooks" +BITBOT_EXPORTS_MAGIC = "__bitbot_exports" diff --git a/src/utils/http.py b/src/utils/http.py index ddf88b2b..b949e9ff 100644 --- a/src/utils/http.py +++ b/src/utils/http.py @@ -1,4 +1,4 @@ -import re, signal, traceback, urllib.error, urllib.parse +import re, signal, traceback, typing, urllib.error, urllib.parse import json as _json import bs4, requests @@ -18,9 +18,10 @@ class HTTPParsingException(HTTPException): def throw_timeout(): raise HTTPTimeoutException() -def get_url(url, method="GET", get_params={}, post_data=None, headers={}, - json_data=None, code=False, json=False, soup=False, parser="lxml", - fallback_encoding="utf8"): +def get_url(url: str, method: str="GET", get_params: dict={}, + post_data: typing.Any=None, headers: dict={}, + json_data: typing.Any=None, code: bool=False, json: bool=False, + soup: bool=False, parser: str="lxml", fallback_encoding: str="utf8"): if not urllib.parse.urlparse(url).scheme: url = "http://%s" % url @@ -66,6 +67,6 @@ def get_url(url, method="GET", get_params={}, post_data=None, headers={}, else: return data -def strip_html(s): +def strip_html(s: str) -> str: return bs4.BeautifulSoup(s, "lxml").get_text() diff --git a/src/utils/irc.py b/src/utils/irc.py index 792de7f3..3e7f8b76 100644 --- a/src/utils/irc.py +++ b/src/utils/irc.py @@ -1,4 +1,4 @@ -import string, re +import string, re, typing ASCII_UPPER = string.ascii_uppercase ASCII_LOWER = string.ascii_lowercase @@ -7,32 +7,36 @@ STRICT_RFC1459_LOWER = ASCII_LOWER+r'|{}' RFC1459_UPPER = STRICT_RFC1459_UPPER+"^" RFC1459_LOWER = STRICT_RFC1459_LOWER+"~" -def remove_colon(s): +def remove_colon(s: str) -> str: if s.startswith(":"): s = s[1:] return s +MULTI_REPLACE_ITERABLE = typing.Iterable[str] # case mapping lowercase/uppcase logic -def _multi_replace(s, chars1, chars2): +def _multi_replace(s: str, + chars1: typing.Iterable[str], + chars2: typing.Iterable[str]) -> str: for char1, char2 in zip(chars1, chars2): s = s.replace(char1, char2) return s -def lower(server, s): - if server.case_mapping == "ascii": +def lower(case_mapping: str, s: str) -> str: + if case_mapping == "ascii": return _multi_replace(s, ASCII_UPPER, ASCII_LOWER) - elif server.case_mapping == "rfc1459": + elif case_mapping == "rfc1459": return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER) - elif server.case_mapping == "strict-rfc1459": + elif case_mapping == "strict-rfc1459": return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER) else: - raise ValueError("unknown casemapping '%s'" % server.case_mapping) + raise ValueError("unknown casemapping '%s'" % case_mapping) # compare a string while respecting case mapping -def equals(server, s1, s2): - return lower(server, s1) == lower(server, s2) +def equals(case_mapping: str, s1: str, s2: str) -> bool: + return lower(case_mapping, s1) == lower(case_mapping, s2) class IRCHostmask(object): - def __init__(self, nickname, username, hostname, hostmask): + def __init__(self, nickname: str, username: str, hostname: str, + hostmask: str): self.nickname = nickname self.username = username self.hostname = hostname @@ -42,24 +46,24 @@ class IRCHostmask(object): def __str__(self): return self.hostmask -def seperate_hostmask(hostmask): +def seperate_hostmask(hostmask: str) -> IRCHostmask: hostmask = remove_colon(hostmask) nickname, _, username = hostmask.partition("!") username, _, hostname = username.partition("@") return IRCHostmask(nickname, username, hostname, hostmask) - class IRCLine(object): - def __init__(self, tags, prefix, command, args, arbitrary, last, server): + def __init__(self, tags: dict, prefix: str, command: str, + args: typing.List[str], arbitrary: typing.Optional[str], + last: str): self.tags = tags self.prefix = prefix self.command = command self.args = args self.arbitrary = arbitrary self.last = last - self.server = server -def parse_line(server, line): +def parse_line(line: str) -> IRCLine: tags = {} prefix = None command = None @@ -81,7 +85,7 @@ def parse_line(server, line): args = line.split(" ") last = arbitrary or args[-1] - return IRCLine(tags, prefix, command, args, arbitrary, last, server) + return IRCLine(tags, prefix, command, args, arbitrary, last) COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3 COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7 @@ -94,20 +98,20 @@ FONT_BOLD, FONT_ITALIC, FONT_UNDERLINE, FONT_INVERT = ("\x02", "\x1D", FONT_COLOR, FONT_RESET = "\x03", "\x0F" REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR) -def color(s, foreground, background=None): +def color(s: str, foreground: str, background: str=None) -> str: foreground = str(foreground).zfill(2) if background: background = str(background).zfill(2) return "%s%s%s%s%s" % (FONT_COLOR, foreground, "" if not background else ",%s" % background, s, FONT_COLOR) -def bold(s): +def bold(s: str) -> str: return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD) -def underline(s): +def underline(s: str) -> str: return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE) -def strip_font(s): +def strip_font(s: str) -> str: s = s.replace(FONT_BOLD, "") s = s.replace(FONT_ITALIC, "") s = REGEX_COLOR.sub("", s) diff --git a/src/utils/parse.py b/src/utils/parse.py new file mode 100644 index 00000000..03a585c2 --- /dev/null +++ b/src/utils/parse.py @@ -0,0 +1,57 @@ +import io, typing + +COMMENT_TYPES = ["#", "//"] +def hashflags(filename: str) -> typing.List[typing.Tuple[str, str]]: + hashflags = {} + with io.open(filename, mode="r", encoding="utf8") as f: + for line in f: + line = line.strip("\n") + found = False + for comment_type in COMMENT_TYPES: + if line.startswith(comment_type): + line = line.replace(comment_type, "", 1).lstrip() + found = True + break + + if not found: + break + elif line.startswith("--"): + hashflag, sep, value = line[2:].partition(" ") + hashflags[hashflag] = value if sep else None + return list(hashflags.items()) + +class Docstring(object): + def __init__(self, description: str, items: dict, var_items: dict): + self.description = description + self.items = items + self.var_items = var_items + +def docstring(s: str) -> Docstring: + description = "" + last_item = None + items = {} + var_items = {} + if s: + for line in s.split("\n"): + line = line.strip() + + if line: + if line[0] == ":": + key, _, value = line[1:].partition(": ") + last_item = key + + if key in var_items: + var_items[key].append(value) + elif key in items: + var_items[key] = [items.pop(key), value] + else: + items[key] = value + else: + if last_item: + items[last_item] += " %s" % line + else: + if description: + description += " " + description += line + return Docstring(description, items, var_items) +