Add type/return hints throughout src/ and, in doing so, fix some cyclical
references.
This commit is contained in:
parent
705daaf9bb
commit
e07553c362
22 changed files with 605 additions and 516 deletions
|
@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule):
|
|||
if not event["user"].last_karma or (time.time()-event["user"
|
||||
].last_karma) >= KARMA_DELAY_SECONDS:
|
||||
target = match.group(1).strip()
|
||||
if utils.irc.lower(event["server"], target
|
||||
if utils.irc.lower(event["server"].case_mapping, target
|
||||
) == event["user"].name:
|
||||
if verbose:
|
||||
self.events.on("send.stderr").call(
|
||||
|
|
|
@ -542,7 +542,8 @@ class Module(ModuleManager.BaseModule):
|
|||
# we need a registered nickname for this channel
|
||||
@utils.hook("raw.477", default_event=True)
|
||||
def handle_477(self, event):
|
||||
channel_name = utils.irc.lower(event["server"], event["args"][1])
|
||||
channel_name = utils.irc.lower(event["server"].case_mapping,
|
||||
event["args"][1])
|
||||
if channel_name in event["server"]:
|
||||
key = event["server"].attempted_join[channel_name]
|
||||
self.timers.add("rejoin", 5, channel_name=channe_name, key=key,
|
||||
|
|
|
@ -11,12 +11,16 @@ REGEX_SED = re.compile("^s/")
|
|||
"help": "Disable/Enable sed only looking at the messages sent by the user",
|
||||
"validate": utils.bool_or_none})
|
||||
class Module(ModuleManager.BaseModule):
|
||||
def _closest_setting(self, event, setting, default):
|
||||
return event["channel"].get_setting(setting,
|
||||
event["server"].get_setting(setting, default))
|
||||
|
||||
@utils.hook("received.message.channel")
|
||||
def channel_message(self, event):
|
||||
sed_split = re.split(REGEX_SPLIT, event["message"], 3)
|
||||
if event["message"].startswith("s/") and len(sed_split) > 2:
|
||||
if event["action"] or not utils.get_closest_setting(
|
||||
event, "sed", False):
|
||||
if event["action"] or not self._closest_setting(event, "sed",
|
||||
False):
|
||||
return
|
||||
|
||||
regex_flags = 0
|
||||
|
@ -48,9 +52,8 @@ class Module(ModuleManager.BaseModule):
|
|||
return
|
||||
replace = sed_split[2].replace("\\/", "/")
|
||||
|
||||
for_user = event["user"].nickname if utils.get_closest_setting(
|
||||
event, "sed-sender-only", False
|
||||
) else None
|
||||
for_user = event["user"].nickname if self._closest_setting(event,
|
||||
"sed-sender-only", False) else None
|
||||
line = event["channel"].buffer.find(pattern, from_self=False,
|
||||
for_user=for_user, not_pattern=REGEX_SED)
|
||||
if line:
|
||||
|
|
18
src/Cache.py
18
src/Cache.py
|
@ -1,21 +1,21 @@
|
|||
import time, uuid
|
||||
import time, typing, uuid
|
||||
|
||||
class Cache(object):
|
||||
def __init__(self):
|
||||
self._items = {}
|
||||
self._item_to_id = {}
|
||||
|
||||
def cache(self, item):
|
||||
def cache(self, item: typing.Any) -> str:
|
||||
return self._cache(item, None)
|
||||
def temporary_cache(self, item, timeout):
|
||||
def temporary_cache(self, item: typing.Any, timeout: float)-> str:
|
||||
return self._cache(item, timeout)
|
||||
def _cache(self, item, timeout):
|
||||
def _cache(self, item: typing.Any, timeout: float) -> str:
|
||||
id = str(uuid.uuid4())
|
||||
self._items[id] = [item, time.monotonic()+timeout]
|
||||
self._item_to_id[item] = id
|
||||
return id
|
||||
|
||||
def next_expiration(self):
|
||||
def next_expiration(self) -> float:
|
||||
expirations = [self._items[id][1] for id in self._items]
|
||||
expirations = list(filter(None, expirations))
|
||||
if not expirations:
|
||||
|
@ -35,17 +35,17 @@ class Cache(object):
|
|||
del self._items[id]
|
||||
del self._item_to_id[item]
|
||||
|
||||
def has_item(self, item):
|
||||
def has_item(self, item: typing.Any) -> bool:
|
||||
return item in self._item_to_id
|
||||
|
||||
def get(self, id):
|
||||
def get(self, id: str) -> typing.Any:
|
||||
item, expiration = self._items[id]
|
||||
return item
|
||||
|
||||
def get_expiration(self, item):
|
||||
def get_expiration(self, item: typing.Any) -> float:
|
||||
id = self._item_to_id[item]
|
||||
item, expiration = self._items[id]
|
||||
return expiration
|
||||
def until_expiration(self, item):
|
||||
def until_expiration(self, item: typing.Any) -> float:
|
||||
expiration = self.get_expiration(item)
|
||||
return expiration-time.monotonic()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import configparser, os
|
||||
import configparser, os, typing
|
||||
|
||||
class Config(object):
|
||||
def __init__(self, location):
|
||||
def __init__(self, location: str):
|
||||
self.location = location
|
||||
self._config = {}
|
||||
self.load()
|
||||
|
@ -13,10 +13,10 @@ class Config(object):
|
|||
parser.read_string(config_file.read())
|
||||
self._config = dict(parser["bot"].items())
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
return self._config[key]
|
||||
def get(self, key, default=None):
|
||||
def get(self, key: str, default: typing.Any=None) -> typing.Any:
|
||||
return self._config.get(key, default)
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._config
|
||||
|
||||
|
|
101
src/Database.py
101
src/Database.py
|
@ -1,12 +1,14 @@
|
|||
import json, os, sqlite3, threading, time
|
||||
import json, os, sqlite3, threading, time, typing
|
||||
from src import Logging
|
||||
|
||||
class Table(object):
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
|
||||
class Servers(Table):
|
||||
def add(self, alias, hostname, port, password, ipv4, tls, bindhost,
|
||||
nickname, username=None, realname=None):
|
||||
def add(self, alias: str, hostname: str, port: int, password: str,
|
||||
ipv4: bool, tls: bool, bindhost: str,
|
||||
nickname: str, username: str=None, realname: str=None):
|
||||
username = username or nickname
|
||||
realname = realname or nickname
|
||||
self.database.execute(
|
||||
|
@ -18,7 +20,7 @@ class Servers(Table):
|
|||
def get_all(self):
|
||||
return self.database.execute_fetchall(
|
||||
"SELECT server_id, alias FROM servers")
|
||||
def get(self, id):
|
||||
def get(self, id: int):
|
||||
return self.database.execute_fetchone(
|
||||
"""SELECT server_id, alias, hostname, port, password, ipv4,
|
||||
tls, bindhost, nickname, username, realname FROM servers WHERE
|
||||
|
@ -26,46 +28,46 @@ class Servers(Table):
|
|||
[id])
|
||||
|
||||
class Channels(Table):
|
||||
def add(self, server_id, name):
|
||||
def add(self, server_id: int, name: str):
|
||||
self.database.execute("""INSERT OR IGNORE INTO channels
|
||||
(server_id, name) VALUES (?, ?)""",
|
||||
[server_id, name.lower()])
|
||||
def delete(self, channel_id):
|
||||
def delete(self, channel_id: int):
|
||||
self.database.execute("DELETE FROM channels WHERE channel_id=?",
|
||||
[channel_id])
|
||||
def get_id(self, server_id, name):
|
||||
def get_id(self, server_id: int, name: str):
|
||||
value = self.database.execute_fetchone("""SELECT channel_id FROM
|
||||
channels WHERE server_id=? AND name=?""",
|
||||
[server_id, name.lower()])
|
||||
return value if value == None else value[0]
|
||||
|
||||
class Users(Table):
|
||||
def add(self, server_id, nickname):
|
||||
def add(self, server_id: int, nickname: str):
|
||||
self.database.execute("""INSERT OR IGNORE INTO users
|
||||
(server_id, nickname) VALUES (?, ?)""",
|
||||
[server_id, nickname.lower()])
|
||||
def delete(self, user_id):
|
||||
def delete(self, user_id: int):
|
||||
self.database.execute("DELETE FROM users WHERE user_id=?",
|
||||
[user_id])
|
||||
def get_id(self, server_id, nickname):
|
||||
def get_id(self, server_id: int, nickname: str):
|
||||
value = self.database.execute_fetchone("""SELECT user_id FROM
|
||||
users WHERE server_id=? and nickname=?""",
|
||||
[server_id, nickname.lower()])
|
||||
return value if value == None else value[0]
|
||||
|
||||
class BotSettings(Table):
|
||||
def set(self, setting, value):
|
||||
def set(self, setting: str, value: typing.Any):
|
||||
self.database.execute(
|
||||
"INSERT OR REPLACE INTO bot_settings VALUES (?, ?)",
|
||||
[setting.lower(), json.dumps(value)])
|
||||
def get(self, setting, default=None):
|
||||
def get(self, setting: str, default: typing.Any=None):
|
||||
value = self.database.execute_fetchone(
|
||||
"SELECT value FROM bot_settings WHERE setting=?",
|
||||
[setting.lower()])
|
||||
if value:
|
||||
return json.loads(value[0])
|
||||
return default
|
||||
def find(self, pattern, default=[]):
|
||||
def find(self, pattern: str, default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"SELECT setting, value FROM bot_settings WHERE setting LIKE ?",
|
||||
[pattern.lower()])
|
||||
|
@ -74,19 +76,19 @@ class BotSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_prefix(self, prefix, default=[]):
|
||||
def find_prefix(self, prefix: str, default: typing.Any=[]):
|
||||
return self.find("%s%%" % prefix, default)
|
||||
def delete(self, setting):
|
||||
def delete(self, setting: str):
|
||||
self.database.execute(
|
||||
"DELETE FROM bot_settings WHERE setting=?",
|
||||
[setting.lower()])
|
||||
|
||||
class ServerSettings(Table):
|
||||
def set(self, server_id, setting, value):
|
||||
def set(self, server_id: int, setting: str, value: typing.Any):
|
||||
self.database.execute(
|
||||
"INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)",
|
||||
[server_id, setting.lower(), json.dumps(value)])
|
||||
def get(self, server_id, setting, default=None):
|
||||
def get(self, server_id: int, setting: str, default: typing.Any=None):
|
||||
value = self.database.execute_fetchone(
|
||||
"""SELECT value FROM server_settings WHERE
|
||||
server_id=? AND setting=?""",
|
||||
|
@ -94,7 +96,7 @@ class ServerSettings(Table):
|
|||
if value:
|
||||
return json.loads(value[0])
|
||||
return default
|
||||
def find(self, server_id, pattern, default=[]):
|
||||
def find(self, server_id: int, pattern: str, default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT setting, value FROM server_settings WHERE
|
||||
server_id=? AND setting LIKE ?""",
|
||||
|
@ -104,26 +106,26 @@ class ServerSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_prefix(self, server_id, prefix, default=[]):
|
||||
def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]):
|
||||
return self.find_server_settings(server_id, "%s%%" % prefix, default)
|
||||
def delete(self, server_id, setting):
|
||||
def delete(self, server_id: int, setting: str):
|
||||
self.database.execute(
|
||||
"DELETE FROM server_settings WHERE server_id=? AND setting=?",
|
||||
[server_id, setting.lower()])
|
||||
|
||||
class ChannelSettings(Table):
|
||||
def set(self, channel_id, setting, value):
|
||||
def set(self, channel_id: int, setting: str, value: typing.Any):
|
||||
self.database.execute(
|
||||
"INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)",
|
||||
[channel_id, setting.lower(), json.dumps(value)])
|
||||
def get(self, channel_id, setting, default=None):
|
||||
def get(self, channel_id: int, setting: str, default: typing.Any=None):
|
||||
value = self.database.execute_fetchone(
|
||||
"""SELECT value FROM channel_settings WHERE
|
||||
channel_id=? AND setting=?""", [channel_id, setting.lower()])
|
||||
if value:
|
||||
return json.loads(value[0])
|
||||
return default
|
||||
def find(self, channel_id, pattern, default=[]):
|
||||
def find(self, channel_id: int, pattern: str, default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT setting, value FROM channel_settings WHERE
|
||||
channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()])
|
||||
|
@ -132,15 +134,15 @@ class ChannelSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_prefix(self, channel_id, prefix, default=[]):
|
||||
def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]):
|
||||
return self.find_channel_settings(channel_id, "%s%%" % prefix,
|
||||
default)
|
||||
def delete(self, channel_id, setting):
|
||||
def delete(self, channel_id: int, setting: str):
|
||||
self.database.execute(
|
||||
"""DELETE FROM channel_settings WHERE channel_id=?
|
||||
AND setting=?""", [channel_id, setting.lower()])
|
||||
|
||||
def find_by_setting(self, setting, default=[]):
|
||||
def find_by_setting(self, setting: str, default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT channels.server_id, channels.name,
|
||||
channel_settings.value FROM channel_settings
|
||||
|
@ -154,18 +156,19 @@ class ChannelSettings(Table):
|
|||
return default
|
||||
|
||||
class UserSettings(Table):
|
||||
def set(self, user_id, setting, value):
|
||||
def set(self, user_id: int, setting: str, value: typing.Any):
|
||||
self.database.execute(
|
||||
"INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)",
|
||||
[user_id, setting.lower(), json.dumps(value)])
|
||||
def get(self, user_id, setting, default=None):
|
||||
def get(self, user_id: int, setting: str, default: typing.Any=None):
|
||||
value = self.database.execute_fetchone(
|
||||
"""SELECT value FROM user_settings WHERE
|
||||
user_id=? and setting=?""", [user_id, setting.lower()])
|
||||
if value:
|
||||
return json.loads(value[0])
|
||||
return default
|
||||
def find_all_by_setting(self, server_id, setting, default=[]):
|
||||
def find_all_by_setting(self, server_id: int, setting: str,
|
||||
default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT users.nickname, user_settings.value FROM
|
||||
user_settings INNER JOIN users ON
|
||||
|
@ -177,7 +180,7 @@ class UserSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find(self, user_id, pattern, default=[]):
|
||||
def find(self, user_id: int, pattern: str, default: typing.Any=[]):
|
||||
values = self.database.execute(
|
||||
"""SELECT setting, value FROM user_settings WHERE
|
||||
user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()])
|
||||
|
@ -186,20 +189,22 @@ class UserSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_prefix(self, user_id, prefix, default=[]):
|
||||
def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]):
|
||||
return self.find_user_settings(user_id, "%s%%" % prefix, default)
|
||||
def delete(self, user_id, setting):
|
||||
def delete(self, user_id: int, setting: str):
|
||||
self.database.execute(
|
||||
"""DELETE FROM user_settings WHERE
|
||||
user_id=? AND setting=?""", [user_id, setting.lower()])
|
||||
|
||||
class UserChannelSettings(Table):
|
||||
def set(self, user_id, channel_id, setting, value):
|
||||
def set(self, user_id: int, channel_id: int, setting: str,
|
||||
value: typing.Any):
|
||||
self.database.execute(
|
||||
"""INSERT OR REPLACE INTO user_channel_settings VALUES
|
||||
(?, ?, ?, ?)""",
|
||||
[user_id, channel_id, setting.lower(), json.dumps(value)])
|
||||
def get(self, user_id, channel_id, setting, default=None):
|
||||
def get(self, user_id: int, channel_id: int, setting: str,
|
||||
default: typing.Any=None):
|
||||
value = self.database.execute_fetchone(
|
||||
"""SELECT value FROM user_channel_settings WHERE
|
||||
user_id=? AND channel_id=? AND setting=?""",
|
||||
|
@ -207,7 +212,8 @@ class UserChannelSettings(Table):
|
|||
if value:
|
||||
return json.loads(value[0])
|
||||
return default
|
||||
def find(self, user_id, channel_id, pattern, default=[]):
|
||||
def find(self, user_id: int, channel_id: int, pattern: str,
|
||||
default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT setting, value FROM user_channel_settings WHERE
|
||||
user_id=? AND channel_id=? AND setting LIKE '?'""",
|
||||
|
@ -217,10 +223,12 @@ class UserChannelSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_prefix(self, user_id, channel_id, prefix, default=[]):
|
||||
def find_prefix(self, user_id: int, channel_id: int, prefix: str,
|
||||
default: typing.Any=[]):
|
||||
return self.find_user_settings(user_id, channel_id, "%s%%" % prefix,
|
||||
default)
|
||||
def find_by_setting(self, user_id, setting, default=[]):
|
||||
def find_by_setting(self, user_id: int, setting: str,
|
||||
default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT channels.name, user_channel_settings.value FROM
|
||||
user_channel_settings INNER JOIN channels ON
|
||||
|
@ -232,7 +240,8 @@ class UserChannelSettings(Table):
|
|||
values[i] = value[0], json.loads(value[1])
|
||||
return values
|
||||
return default
|
||||
def find_all_by_setting(self, server_id, setting, default=[]):
|
||||
def find_all_by_setting(self, server_id: int, setting: str,
|
||||
default: typing.Any=[]):
|
||||
values = self.database.execute_fetchall(
|
||||
"""SELECT channels.name, users.nickname,
|
||||
user_channel_settings.value FROM
|
||||
|
@ -246,14 +255,14 @@ class UserChannelSettings(Table):
|
|||
values[i] = value[0], value[1], json.loads(value[2])
|
||||
return values
|
||||
return default
|
||||
def delete(self, user_id, channel_id, setting):
|
||||
def delete(self, user_id: int, channel_id: int, setting: str):
|
||||
self.database.execute(
|
||||
"""DELETE FROM user_channel_settings WHERE
|
||||
user_id=? AND channel_id=? AND setting=?""",
|
||||
[user_id, channel_id, setting.lower()])
|
||||
|
||||
class Database(object):
|
||||
def __init__(self, log, location):
|
||||
def __init__(self, log: "Logging.Log", location: str):
|
||||
self.log = log
|
||||
self.location = location
|
||||
self.database = sqlite3.connect(self.location,
|
||||
|
@ -284,7 +293,9 @@ class Database(object):
|
|||
self._cursor = self.database.cursor()
|
||||
return self._cursor
|
||||
|
||||
def _execute_fetch(self, query, fetch_func, params=[]):
|
||||
def _execute_fetch(self, query: str,
|
||||
fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
|
||||
params: typing.List=[]):
|
||||
printable_query = " ".join(query.split())
|
||||
self.log.trace("executing query: \"%s\" (params: %s)",
|
||||
[printable_query, params])
|
||||
|
@ -299,16 +310,16 @@ class Database(object):
|
|||
self.log.trace("executed in %fms", [total_milliseconds])
|
||||
|
||||
return value
|
||||
def execute_fetchall(self, query, params=[]):
|
||||
def execute_fetchall(self, query: str, params: typing.List=[]):
|
||||
return self._execute_fetch(query,
|
||||
lambda cursor: cursor.fetchall(), params)
|
||||
def execute_fetchone(self, query, params=[]):
|
||||
def execute_fetchone(self, query: str, params: typing.List=[]):
|
||||
return self._execute_fetch(query,
|
||||
lambda cursor: cursor.fetchone(), params)
|
||||
def execute(self, query, params=[]):
|
||||
def execute(self, query: str, params: typing.List=[]):
|
||||
return self._execute_fetch(query, lambda cursor: None, params)
|
||||
|
||||
def has_table(self, table_name):
|
||||
def has_table(self, table_name: str):
|
||||
result = self.execute_fetchone("""SELECT COUNT(*) FROM
|
||||
sqlite_master WHERE type='table' AND name=?""",
|
||||
[table_name])
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import itertools, time, traceback
|
||||
from src import utils
|
||||
import itertools, time, traceback, typing
|
||||
from src import Logging, utils
|
||||
|
||||
PRIORITY_URGENT = 0
|
||||
PRIORITY_HIGH = 1
|
||||
|
@ -11,94 +11,39 @@ DEFAULT_PRIORITY = PRIORITY_MEDIUM
|
|||
DEFAULT_EVENT_DELIMITER = "."
|
||||
DEFAULT_MULTI_DELIMITER = "|"
|
||||
|
||||
CALLBACK_TYPE = typing.Callable[["Event"], typing.Any]
|
||||
|
||||
class Event(object):
|
||||
def __init__(self, name, **kwargs):
|
||||
def __init__(self, name: str, **kwargs):
|
||||
self.name = name
|
||||
self.kwargs = kwargs
|
||||
self.eaten = False
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
return self.kwargs[key]
|
||||
def get(self, key, default=None):
|
||||
def get(self, key: str, default=None) -> typing.Any:
|
||||
return self.kwargs.get(key, default)
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.kwargs
|
||||
def eat(self):
|
||||
self.eaten = True
|
||||
|
||||
class EventCallback(object):
|
||||
def __init__(self, function, priority, kwargs):
|
||||
def __init__(self, function: CALLBACK_TYPE, priority: int, kwargs: dict):
|
||||
self.function = function
|
||||
self.priority = priority
|
||||
self.kwargs = kwargs
|
||||
self.docstring = utils.parse_docstring(function.__doc__)
|
||||
self.docstring = utils.parse.docstring(function.__doc__)
|
||||
|
||||
def call(self, event):
|
||||
def call(self, event: Event) -> typing.Any:
|
||||
return self.function(event)
|
||||
|
||||
def get_kwarg(self, name, default=None):
|
||||
def get_kwarg(self, name: str, default=None) -> typing.Any:
|
||||
item = self.kwargs.get(name, default)
|
||||
return item or self.docstring.items.get(name, default)
|
||||
|
||||
class MultipleEventHook(object):
|
||||
def __init__(self):
|
||||
self._event_hooks = set([])
|
||||
def _add(self, event_hook):
|
||||
self._event_hooks.add(event_hook)
|
||||
|
||||
def hook(self, function, **kwargs):
|
||||
for event_hook in self._event_hooks:
|
||||
event_hook.hook(function, **kwargs)
|
||||
|
||||
def call_limited(self, maximum, **kwargs):
|
||||
returns = []
|
||||
for event_hook in self._event_hooks:
|
||||
returns.append(event_hook.call_limited(maximum, **kwargs))
|
||||
return returns
|
||||
def call(self, **kwargs):
|
||||
returns = []
|
||||
for event_hook in self._event_hooks:
|
||||
returns.append(event_hook.call(**kwargs))
|
||||
return returns
|
||||
|
||||
class EventHookContext(object):
|
||||
def __init__(self, parent, context):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
|
||||
**kwargs):
|
||||
return self._parent._context_hook(self.context, function, priority,
|
||||
replay, kwargs)
|
||||
def unhook(self, callback):
|
||||
self._parent.unhook(callback)
|
||||
|
||||
def on(self, subevent, *extra_subevents,
|
||||
delimiter=DEFAULT_EVENT_DELIMITER):
|
||||
return self._parent._context_on(self.context, subevent,
|
||||
extra_subevents, delimiter)
|
||||
|
||||
def call_for_result(self, default=None, **kwargs):
|
||||
return self._parent.call_for_result(default, **kwargs)
|
||||
def assure_call(self, **kwargs):
|
||||
self._parent.assure_call(**kwargs)
|
||||
def call(self, **kwargs):
|
||||
return self._parent.call(**kwargs)
|
||||
def call_limited(self, maximum, **kwargs):
|
||||
return self._parent.call_limited(maximum, **kwargs)
|
||||
|
||||
def call_unsafe_for_result(self, default=None, **kwargs):
|
||||
return self._parent.call_unsafe_for_result(default, **kwargs)
|
||||
def call_unsafe(self, **kwargs):
|
||||
return self._parent.call_unsafe(**kwargs)
|
||||
def call_unsafe_limited(self, maximum, **kwargs):
|
||||
return self._parent.call_unsafe_limited(maximum, **kwargs)
|
||||
|
||||
def get_hooks(self):
|
||||
return self._parent.get_hooks()
|
||||
def get_children(self):
|
||||
return self._parent.get_children()
|
||||
|
||||
class EventHook(object):
|
||||
def __init__(self, log, name=None, parent=None):
|
||||
def __init__(self, log: Logging.Log, name: str = None,
|
||||
parent: "EventHook" = None):
|
||||
self.log = log
|
||||
self.name = name
|
||||
self.parent = parent
|
||||
|
@ -107,10 +52,10 @@ class EventHook(object):
|
|||
self._stored_events = []
|
||||
self._context_hooks = {}
|
||||
|
||||
def _make_event(self, kwargs):
|
||||
def _make_event(self, kwargs: dict) -> Event:
|
||||
return Event(self._get_path(), **kwargs)
|
||||
|
||||
def _get_path(self):
|
||||
def _get_path(self) -> str:
|
||||
path = []
|
||||
parent = self
|
||||
while not parent == None and not parent.name == None:
|
||||
|
@ -118,15 +63,17 @@ class EventHook(object):
|
|||
parent = parent.parent
|
||||
return DEFAULT_EVENT_DELIMITER.join(path[::-1])
|
||||
|
||||
def new_context(self, context):
|
||||
def new_context(self, context: str) -> "EventHookContext":
|
||||
return EventHookContext(self, context)
|
||||
|
||||
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
|
||||
**kwargs):
|
||||
def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
|
||||
replay: bool = False, **kwargs) -> EventCallback:
|
||||
return self._hook(function, None, priority, replay, kwargs)
|
||||
def _context_hook(self, context, function, priority, replay, kwargs):
|
||||
def _context_hook(self, context: str, function: CALLBACK_TYPE,
|
||||
priority: int, replay: bool, kwargs: dict) -> EventCallback:
|
||||
return self._hook(function, context, priority, replay, kwargs)
|
||||
def _hook(self, function, context, priority, replay, kwargs):
|
||||
def _hook(self, function: CALLBACK_TYPE, context: str, priority: int,
|
||||
replay: bool, kwargs: dict) -> EventCallback:
|
||||
callback = EventCallback(function, priority, kwargs)
|
||||
|
||||
if context == None:
|
||||
|
@ -142,7 +89,7 @@ class EventHook(object):
|
|||
self._stored_events = None
|
||||
return callback
|
||||
|
||||
def unhook(self, callback):
|
||||
def unhook(self, callback: "EventHook"):
|
||||
if callback in self._hooks:
|
||||
self._hooks.remove(callback)
|
||||
|
||||
|
@ -155,7 +102,8 @@ class EventHook(object):
|
|||
for context in empty:
|
||||
del self._context_hooks[context]
|
||||
|
||||
def _make_multiple_hook(self, source, context, events):
|
||||
def _make_multiple_hook(self, source: "EventHook", context: str,
|
||||
events: typing.List[str]) -> "MultipleEventHook":
|
||||
multiple_event_hook = MultipleEventHook()
|
||||
for event in events:
|
||||
event_hook = source.get_child(event)
|
||||
|
@ -164,13 +112,15 @@ class EventHook(object):
|
|||
multiple_event_hook._add(event_hook)
|
||||
return multiple_event_hook
|
||||
|
||||
def on(self, subevent, *extra_subevents,
|
||||
delimiter=DEFAULT_EVENT_DELIMITER):
|
||||
def on(self, subevent: str, *extra_subevents,
|
||||
delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook":
|
||||
return self._on(subevent, extra_subevents, None, delimiter)
|
||||
def _context_on(self, context, subevent, extra_subevents,
|
||||
delimiter=DEFAULT_EVENT_DELIMITER):
|
||||
def _context_on(self, context: str, subevent: str,
|
||||
extra_subevents: typing.List[str],
|
||||
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
|
||||
return self._on(subevent, extra_subevents, context, delimiter)
|
||||
def _on(self, subevent, extra_subevents, context, delimiter):
|
||||
def _on(self, subevent: str, extra_subevents: typing.List[str],
|
||||
context: str, delimiter: str) -> "EventHook":
|
||||
if delimiter in subevent:
|
||||
event_chain = subevent.split(delimiter)
|
||||
event_obj = self
|
||||
|
@ -193,26 +143,28 @@ class EventHook(object):
|
|||
child = child.new_context(context)
|
||||
return child
|
||||
|
||||
def call_for_result(self, default=None, **kwargs):
|
||||
def call_for_result(self, default=None, **kwargs) -> typing.Any:
|
||||
return (self.call_limited(1, **kwargs) or [default])[0]
|
||||
def assure_call(self, **kwargs):
|
||||
if not self._stored_events == None:
|
||||
self._stored_events.append(kwargs)
|
||||
else:
|
||||
self._call(kwargs, True, None)
|
||||
def call(self, **kwargs):
|
||||
def call(self, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._call(kwargs, True, None)
|
||||
def call_limited(self, maximum, **kwargs):
|
||||
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._call(kwargs, True, None)
|
||||
|
||||
def call_unsafe_for_result(self, default=None, **kwargs):
|
||||
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
|
||||
return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
|
||||
def call_unsafe(self, **kwargs):
|
||||
def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._call(kwargs, False, None)
|
||||
def call_unsafe_limited(self, maximum, **kwargs):
|
||||
def call_unsafe_limited(self, maximum: int, **kwargs
|
||||
) -> typing.List[typing.Any]:
|
||||
return self._call(kwargs, False, maximum)
|
||||
|
||||
def _call(self, kwargs, safe, maximum):
|
||||
def _call(self, kwargs: dict, safe: bool, maximum: int
|
||||
) -> typing.List[typing.Any]:
|
||||
event_path = self._get_path()
|
||||
self.log.trace("calling event: \"%s\" (params: %s)",
|
||||
[event_path, kwargs])
|
||||
|
@ -240,13 +192,13 @@ class EventHook(object):
|
|||
|
||||
return returns
|
||||
|
||||
def get_child(self, child_name):
|
||||
def get_child(self, child_name: str) -> "EventHook":
|
||||
child_name_lower = child_name.lower()
|
||||
if not child_name_lower in self._children:
|
||||
self._children[child_name_lower] = EventHook(self.log,
|
||||
child_name_lower, self)
|
||||
return self._children[child_name_lower]
|
||||
def remove_child(self, child_name):
|
||||
def remove_child(self, child_name: str):
|
||||
child_name_lower = child_name.lower()
|
||||
if child_name_lower in self._children:
|
||||
del self._children[child_name_lower]
|
||||
|
@ -256,11 +208,11 @@ class EventHook(object):
|
|||
self.parent.remove_child(self.name)
|
||||
self.parent.check_purge()
|
||||
|
||||
def remove_context(self, context):
|
||||
def remove_context(self, context: str):
|
||||
del self._context_hooks[context]
|
||||
def has_context(self, context):
|
||||
def has_context(self, context: str) -> bool:
|
||||
return context in self._context_hooks
|
||||
def purge_context(self, context):
|
||||
def purge_context(self, context: str):
|
||||
if self.has_context(context):
|
||||
self.remove_context(context)
|
||||
|
||||
|
@ -268,10 +220,69 @@ class EventHook(object):
|
|||
child = self.get_child(child_name)
|
||||
child.purge_context(context)
|
||||
|
||||
def get_hooks(self):
|
||||
def get_hooks(self) -> typing.List[EventCallback]:
|
||||
return sorted(self._hooks + sum(self._context_hooks.values(), []),
|
||||
key=lambda e: e.priority)
|
||||
def get_children(self):
|
||||
def get_children(self) -> typing.List["EventHook"]:
|
||||
return list(self._children.keys())
|
||||
def is_empty(self):
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.get_hooks() + self.get_children()) == 0
|
||||
|
||||
class MultipleEventHook(object):
|
||||
def __init__(self):
|
||||
self._event_hooks = set([])
|
||||
def _add(self, event_hook: EventHook):
|
||||
self._event_hooks.add(event_hook)
|
||||
|
||||
def hook(self, function: CALLBACK_TYPE, **kwargs):
|
||||
for event_hook in self._event_hooks:
|
||||
event_hook.hook(function, **kwargs)
|
||||
|
||||
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
|
||||
returns = []
|
||||
for event_hook in self._event_hooks:
|
||||
returns.append(event_hook.call_limited(maximum, **kwargs))
|
||||
return returns
|
||||
def call(self, **kwargs) -> typing.List[typing.Any]:
|
||||
returns = []
|
||||
for event_hook in self._event_hooks:
|
||||
returns.append(event_hook.call(**kwargs))
|
||||
return returns
|
||||
|
||||
class EventHookContext(object):
|
||||
def __init__(self, parent, context):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
|
||||
replay: bool = False, **kwargs) -> EventCallback:
|
||||
return self._parent._context_hook(self.context, function, priority,
|
||||
replay, kwargs)
|
||||
def unhook(self, callback: EventCallback):
|
||||
self._parent.unhook(callback)
|
||||
|
||||
def on(self, subevent: str, *extra_subevents,
|
||||
delimiter: str = DEFAULT_EVENT_DELIMITER) -> EventHook:
|
||||
return self._parent._context_on(self.context, subevent,
|
||||
extra_subevents, delimiter)
|
||||
|
||||
def call_for_result(self, default=None, **kwargs) -> typing.Any:
|
||||
return self._parent.call_for_result(default, **kwargs)
|
||||
def assure_call(self, **kwargs):
|
||||
self._parent.assure_call(**kwargs)
|
||||
def call(self, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._parent.call(**kwargs)
|
||||
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._parent.call_limited(maximum, **kwargs)
|
||||
|
||||
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
|
||||
return self._parent.call_unsafe_for_result(default, **kwargs)
|
||||
def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
|
||||
return self._parent.call_unsafe(**kwargs)
|
||||
def call_unsafe_limited(self, maximum: int, **kwargs
|
||||
) -> typing.List[typing.Any]:
|
||||
return self._parent.call_unsafe_limited(maximum, **kwargs)
|
||||
|
||||
def get_hooks(self) -> typing.List[EventCallback]:
|
||||
return self._parent.get_hooks()
|
||||
def get_children(self) -> typing.List[EventHook]:
|
||||
return self._parent.get_children()
|
||||
|
|
|
@ -1,28 +1,18 @@
|
|||
|
||||
|
||||
class ExportsContext(object):
|
||||
def __init__(self, parent, context):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
|
||||
def add(self, setting, value):
|
||||
self._parent._context_add(self.context, setting, value)
|
||||
def get_all(self, setting):
|
||||
return self._parent.get_all(setting)
|
||||
import typing
|
||||
|
||||
class Exports(object):
|
||||
def __init__(self):
|
||||
self._exports = {}
|
||||
self._context_exports = {}
|
||||
|
||||
def new_context(self, context):
|
||||
def new_context(self, context: str) -> "ExportsContext":
|
||||
return ExportsContext(self, context)
|
||||
|
||||
def add(self, setting, value):
|
||||
def add(self, setting: str, value: typing.Any):
|
||||
self._add(None, setting, value)
|
||||
def _context_add(self, context, setting, value):
|
||||
def _context_add(self, context: str, setting: str, value: typing.Any):
|
||||
self._add(context, setting, value)
|
||||
def _add(self, context, setting, value):
|
||||
def _add(self, context: str, setting: str, value: typing.Any):
|
||||
if context == None:
|
||||
if not setting in self_exports:
|
||||
self._exports[setting] = []
|
||||
|
@ -34,11 +24,21 @@ class Exports(object):
|
|||
self._context_exports[context][setting] = []
|
||||
self._context_exports[context][setting].append(value)
|
||||
|
||||
def get_all(self, setting):
|
||||
def get_all(self, setting: str) -> typing.List[typing.Any]:
|
||||
return self._exports.get(setting, []) + sum([
|
||||
exports.get(setting, []) for exports in
|
||||
self._context_exports.values()], [])
|
||||
|
||||
def purge_context(self, context):
|
||||
def purge_context(self, context: str):
|
||||
if context in self._context_exports:
|
||||
del self._context_exports[context]
|
||||
|
||||
class ExportsContext(object):
|
||||
def __init__(self, parent: Exports, context: str):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
|
||||
def add(self, setting: str, value: typing.Any):
|
||||
self._parent._context_add(self.context, setting, value)
|
||||
def get_all(self, setting: str) -> typing.List[typing.Any]:
|
||||
return self._parent.get_all(setting)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import os, select, socket, sys, threading, time, traceback, uuid
|
||||
import os, select, socket, sys, threading, time, traceback, typing, uuid
|
||||
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
|
||||
from src import Socket, utils
|
||||
|
||||
|
@ -28,14 +28,15 @@ class Bot(object):
|
|||
|
||||
self._trigger_functions = []
|
||||
|
||||
def trigger(self, func=None):
|
||||
def trigger(self, func: typing.Callable[[], typing.Any]=None):
|
||||
self.lock.acquire()
|
||||
if func:
|
||||
self._trigger_functions.append(func)
|
||||
self._trigger_client.send(b"TRIGGER")
|
||||
self.lock.release()
|
||||
|
||||
def add_server(self, server_id, connect=True):
|
||||
def add_server(self, server_id: int, connect: bool = True
|
||||
) -> typing.Optional[IRCServer.Server]:
|
||||
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
|
||||
username, realname) = self.database.servers.get(server_id)
|
||||
|
||||
|
@ -49,20 +50,20 @@ class Bot(object):
|
|||
self.connect(new_server)
|
||||
return new_server
|
||||
|
||||
def add_socket(self, sock):
|
||||
def add_socket(self, sock: socket.socket):
|
||||
self.other_sockets[sock.fileno()] = sock
|
||||
self.poll.register(sock.fileno(), select.EPOLLIN)
|
||||
|
||||
def remove_socket(self, sock):
|
||||
def remove_socket(self, sock: socket.socket):
|
||||
del self.other_sockets[sock.fileno()]
|
||||
self.poll.unregister(sock.fileno())
|
||||
|
||||
def get_server(self, id):
|
||||
def get_server(self, id: int) -> typing.Optional[IRCServer.Server]:
|
||||
for server in self.servers.values():
|
||||
if server.id == id:
|
||||
return server
|
||||
|
||||
def connect(self, server):
|
||||
def connect(self, server: IRCServer.Server) -> bool:
|
||||
try:
|
||||
server.connect()
|
||||
except:
|
||||
|
@ -73,7 +74,7 @@ class Bot(object):
|
|||
self.poll.register(server.fileno(), select.EPOLLOUT)
|
||||
return True
|
||||
|
||||
def next_send(self):
|
||||
def next_send(self) -> typing.Optional[float]:
|
||||
next = None
|
||||
for server in self.servers.values():
|
||||
timeout = server.send_throttle_timeout()
|
||||
|
@ -81,7 +82,7 @@ class Bot(object):
|
|||
next = timeout
|
||||
return next
|
||||
|
||||
def next_ping(self):
|
||||
def next_ping(self) -> typing.Optional[float]:
|
||||
timeouts = []
|
||||
for server in self.servers.values():
|
||||
timeout = server.until_next_ping()
|
||||
|
@ -90,7 +91,8 @@ class Bot(object):
|
|||
if not timeouts:
|
||||
return None
|
||||
return min(timeouts)
|
||||
def next_read_timeout(self):
|
||||
|
||||
def next_read_timeout(self) -> typing.Optional[float]:
|
||||
timeouts = []
|
||||
for server in self.servers.values():
|
||||
timeouts.append(server.until_read_timeout())
|
||||
|
@ -98,7 +100,7 @@ class Bot(object):
|
|||
return None
|
||||
return min(timeouts)
|
||||
|
||||
def get_poll_timeout(self):
|
||||
def get_poll_timeout(self) -> float:
|
||||
timeouts = []
|
||||
timeouts.append(self._timers.next())
|
||||
timeouts.append(self.next_send())
|
||||
|
@ -107,15 +109,15 @@ class Bot(object):
|
|||
timeouts.append(self.cache.next_expiration())
|
||||
return min([timeout for timeout in timeouts if not timeout == None])
|
||||
|
||||
def register_read(self, server):
|
||||
def register_read(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(), select.EPOLLIN)
|
||||
def register_write(self, server):
|
||||
def register_write(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(), select.EPOLLOUT)
|
||||
def register_both(self, server):
|
||||
def register_both(self, server: IRCServer.Server):
|
||||
self.poll.modify(server.fileno(),
|
||||
select.EPOLLIN|select.EPOLLOUT)
|
||||
|
||||
def disconnect(self, server):
|
||||
def disconnect(self, server: IRCServer.Server):
|
||||
try:
|
||||
self.poll.unregister(server.fileno())
|
||||
except FileNotFoundError:
|
||||
|
@ -123,23 +125,25 @@ class Bot(object):
|
|||
del self.servers[server.fileno()]
|
||||
|
||||
@utils.hook("timer.reconnect")
|
||||
def reconnect(self, event):
|
||||
def reconnect(self, event: EventManager.Event):
|
||||
server = self.add_server(event["server_id"], False)
|
||||
if self.connect(server):
|
||||
self.servers[server.fileno()] = server
|
||||
else:
|
||||
event["timer"].redo()
|
||||
|
||||
def set_setting(self, setting, value):
|
||||
def set_setting(self, setting: str, value: typing.Any):
|
||||
self.database.bot_settings.set(setting, value)
|
||||
def get_setting(self, setting, default=None):
|
||||
def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
|
||||
return self.database.bot_settings.get(setting, default)
|
||||
def find_settings(self, pattern, default=[]):
|
||||
def find_settings(self, pattern: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.database.bot_settings.find(pattern, default)
|
||||
def find_settings_prefix(self, prefix, default=[]):
|
||||
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.database.bot_settings.find_prefix(
|
||||
prefix, default)
|
||||
def del_setting(self, setting):
|
||||
def del_setting(self, setting: str):
|
||||
self.database.bot_settings.delete(setting)
|
||||
|
||||
def run(self):
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import re
|
||||
from src import utils
|
||||
import re, typing
|
||||
from src import IRCBot, utils
|
||||
|
||||
class BufferLine(object):
|
||||
def __init__(self, sender, message, action, tags, from_self, method):
|
||||
def __init__(self, sender: str, message: str, action: bool, tags: dict,
|
||||
from_self: bool, method: str):
|
||||
self.sender = sender
|
||||
self.message = message
|
||||
self.action = action
|
||||
|
@ -11,35 +12,39 @@ class BufferLine(object):
|
|||
self.method = method
|
||||
|
||||
class Buffer(object):
|
||||
def __init__(self, bot, server):
|
||||
def __init__(self, bot: "IRCBot.Bot", server: "IRCServer.Server"):
|
||||
self.bot = bot
|
||||
self.server = server
|
||||
self.lines = []
|
||||
self.max_lines = 64
|
||||
self._skip_next = False
|
||||
|
||||
def _add_message(self, sender, message, action, tags, from_self, method):
|
||||
def _add_message(self, sender: str, message: str, action: bool, tags: dict,
|
||||
from_self: bool, method: str):
|
||||
if not self._skip_next:
|
||||
line = BufferLine(sender, message, action, tags, from_self, method)
|
||||
self.lines.insert(0, line)
|
||||
if len(self.lines) > self.max_lines:
|
||||
self.lines.pop()
|
||||
self._skip_next = False
|
||||
def add_message(self, sender, message, action, tags, from_self=False):
|
||||
def add_message(self, sender: str, message: str, action: bool, tags: dict,
|
||||
from_self: bool=False):
|
||||
self._add_message(sender, message, action, tags, from_self, "PRIVMSG")
|
||||
def add_notice(self, sender, message, tags, from_self=False):
|
||||
def add_notice(self, sender: str, message: str, tags: dict,
|
||||
from_self: bool=False):
|
||||
self._add_message(sender, message, False, tags, from_self, "NOTICE")
|
||||
|
||||
def get(self, index=0, **kwargs):
|
||||
def get(self, index: int=0, **kwargs) -> typing.Optional[BufferLine]:
|
||||
from_self = kwargs.get("from_self", True)
|
||||
for line in self.lines:
|
||||
if line.from_self and not from_self:
|
||||
continue
|
||||
return line
|
||||
def find(self, pattern, **kwargs):
|
||||
def find(self, pattern: typing.Union[str, typing.Pattern[str]], **kwargs
|
||||
) -> typing.Optional[BufferLine]:
|
||||
from_self = kwargs.get("from_self", True)
|
||||
for_user = kwargs.get("for_user", "")
|
||||
for_user = utils.irc.lower(self.server, for_user
|
||||
for_user = utils.irc.lower(self.server.case_mapping, for_user
|
||||
) if for_user else None
|
||||
not_pattern = kwargs.get("not_pattern", None)
|
||||
for line in self.lines:
|
||||
|
@ -48,8 +53,8 @@ class Buffer(object):
|
|||
elif re.search(pattern, line.message):
|
||||
if not_pattern and re.search(not_pattern, line.message):
|
||||
continue
|
||||
if for_user and not utils.irc.lower(self.server, line.sender
|
||||
) == for_user:
|
||||
if for_user and not utils.irc.lower(self.server.case_mapping,
|
||||
line.sender) == for_user:
|
||||
continue
|
||||
return line
|
||||
def skip_next(self):
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import uuid
|
||||
from src import IRCBuffer, IRCObject, utils
|
||||
import typing, uuid
|
||||
from src import IRCBot, IRCBuffer, IRCObject, IRCServer, IRCUser, utils
|
||||
|
||||
class Channel(IRCObject.Object):
|
||||
def __init__(self, name, id, server, bot):
|
||||
self.name = utils.irc.lower(server, name)
|
||||
def __init__(self, name: str, id, server: "IRCServer.Server",
|
||||
bot: "IRCBot.Bot"):
|
||||
self.name = utils.irc.lower(server.case_mapping, name)
|
||||
self.id = id
|
||||
self.server = server
|
||||
self.bot = bot
|
||||
|
@ -18,23 +19,24 @@ class Channel(IRCObject.Object):
|
|||
self.created_timestamp = None
|
||||
self.buffer = IRCBuffer.Buffer(bot, server)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name)
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def set_topic(self, topic):
|
||||
def set_topic(self, topic: str):
|
||||
self.topic = topic
|
||||
def set_topic_setter(self, nickname, username=None, hostname=None):
|
||||
def set_topic_setter(self, nickname: str, username: str=None,
|
||||
hostname: str=None):
|
||||
self.topic_setter_nickname = nickname
|
||||
self.topic_setter_username = username
|
||||
self.topic_setter_hostname = hostname
|
||||
def set_topic_time(self, unix_timestamp):
|
||||
def set_topic_time(self, unix_timestamp: int):
|
||||
self.topic_time = unix_timestamp
|
||||
|
||||
def add_user(self, user):
|
||||
def add_user(self, user: IRCUser.User):
|
||||
self.users.add(user)
|
||||
def remove_user(self, user):
|
||||
def remove_user(self, user: IRCUser.User):
|
||||
self.users.remove(user)
|
||||
for mode in list(self.modes.keys()):
|
||||
if mode in self.server.prefix_modes and user in self.modes[mode]:
|
||||
|
@ -43,10 +45,10 @@ class Channel(IRCObject.Object):
|
|||
del self.modes[mode]
|
||||
if user in self.user_modes:
|
||||
del self.user_modes[user]
|
||||
def has_user(self, user):
|
||||
def has_user(self, user: IRCUser.User) -> bool:
|
||||
return user in self.users
|
||||
|
||||
def add_mode(self, mode, arg=None):
|
||||
def add_mode(self, mode: str, arg: str=None):
|
||||
if not mode in self.modes:
|
||||
self.modes[mode] = set([])
|
||||
if arg:
|
||||
|
@ -59,7 +61,7 @@ class Channel(IRCObject.Object):
|
|||
self.user_modes[user].add(mode)
|
||||
else:
|
||||
self.modes[mode].add(arg.lower())
|
||||
def remove_mode(self, mode, arg=None):
|
||||
def remove_mode(self, mode: str, arg: str=None):
|
||||
if not arg:
|
||||
del self.modes[mode]
|
||||
else:
|
||||
|
@ -76,63 +78,70 @@ class Channel(IRCObject.Object):
|
|||
self.modes[mode].discard(arg.lower())
|
||||
if not len(self.modes[mode]):
|
||||
del self.modes[mode]
|
||||
def change_mode(self, remove, mode, arg=None):
|
||||
def change_mode(self, remove: bool, mode: str, arg: str=None):
|
||||
if remove:
|
||||
self.remove_mode(mode, arg)
|
||||
else:
|
||||
self.add_mode(mode, arg)
|
||||
|
||||
def set_setting(self, setting, value):
|
||||
def set_setting(self, setting: str, value: typing.Any):
|
||||
self.bot.database.channel_settings.set(self.id, setting, value)
|
||||
def get_setting(self, setting, default=None):
|
||||
def get_setting(self, setting: str, default: typing.Any=None
|
||||
) -> typing.Any:
|
||||
return self.bot.database.channel_settings.get(self.id, setting,
|
||||
default)
|
||||
def find_settings(self, pattern, default=[]):
|
||||
def find_settings(self, pattern: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.bot.database.channel_settings.find(self.id, pattern,
|
||||
default)
|
||||
def find_settings_prefix(self, prefix, default=[]):
|
||||
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.bot.database.channel_settings.find_prefix(self.id,
|
||||
prefix, default)
|
||||
def del_setting(self, setting):
|
||||
def del_setting(self, setting: str):
|
||||
self.bot.database.channel_settings.delete(self.id, setting)
|
||||
|
||||
def set_user_setting(self, user_id, setting, value):
|
||||
def set_user_setting(self, user_id: int, setting: str, value: typing.Any):
|
||||
self.bot.database.user_channel_settings.set(user_id, self.id,
|
||||
setting, value)
|
||||
def get_user_setting(self, user_id, setting, default=None):
|
||||
def get_user_setting(self, user_id: int, setting: str,
|
||||
default: typing.Any=None) -> typing.Any:
|
||||
return self.bot.database.user_channel_settings.get(user_id,
|
||||
self.id, setting, default)
|
||||
def find_user_settings(self, user_i, pattern, default=[]):
|
||||
def find_user_settings(self, user_id: int, pattern: str,
|
||||
default: typing.Any=[]) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_channel_settings.find(user_id,
|
||||
self.id, pattern, default)
|
||||
def find_user_settings_prefix(self, user_id, prefix, default=[]):
|
||||
def find_user_settings_prefix(self, user_id: int, prefix: str,
|
||||
default: typing.Any=[]) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_channel_settings.find_prefix(
|
||||
user_id, self.id, prefix, default)
|
||||
def del_user_setting(self, user_id, setting):
|
||||
def del_user_setting(self, user_id: int, setting: str):
|
||||
self.bot.database.user_channel_settings.delete(user_id, self.id,
|
||||
setting)
|
||||
def find_all_by_setting(self, setting, default=[]):
|
||||
def find_all_by_setting(self, setting: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_channel_settings.find_all_by_setting(
|
||||
self.id, setting, default)
|
||||
|
||||
def send_message(self, text, prefix=None, tags={}):
|
||||
def send_message(self, text: str, prefix: str=None, tags: dict={}):
|
||||
self.server.send_message(self.name, text, prefix=prefix, tags=tags)
|
||||
def send_notice(self, text, prefix=None, tags={}):
|
||||
def send_notice(self, text: str, prefix: str=None, tags: dict={}):
|
||||
self.server.send_notice(self.name, text, prefix=prefix, tags=tags)
|
||||
def send_mode(self, mode=None, target=None):
|
||||
def send_mode(self, mode: str=None, target: str=None):
|
||||
self.server.send_mode(self.name, mode, target)
|
||||
def send_kick(self, target, reason=None):
|
||||
def send_kick(self, target: str, reason: str=None):
|
||||
self.server.send_kick(self.name, target, reason)
|
||||
def send_ban(self, hostmask):
|
||||
def send_ban(self, hostmask: str):
|
||||
self.server.send_mode(self.name, "+b", hostmask)
|
||||
def send_unban(self, hostmask):
|
||||
def send_unban(self, hostmask: str):
|
||||
self.server.send_mode(self.name, "-b", hostmask)
|
||||
def send_topic(self, topic):
|
||||
def send_topic(self, topic: str):
|
||||
self.server.send_topic(self.name, topic)
|
||||
def send_part(self, reason=None):
|
||||
def send_part(self, reason: str=None):
|
||||
self.server.send_part(self.name, reason)
|
||||
|
||||
def mode_or_above(self, user, mode):
|
||||
def mode_or_above(self, user: IRCUser.User, mode: str) -> bool:
|
||||
mode_orders = list(self.server.prefix_modes)
|
||||
mode_index = mode_orders.index(mode)
|
||||
for mode in mode_orders[:mode_index+1]:
|
||||
|
@ -140,8 +149,8 @@ class Channel(IRCObject.Object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def has_mode(self, user, mode):
|
||||
def has_mode(self, user: IRCUser.User, mode: str) -> bool:
|
||||
return user in self.modes.get(mode, [])
|
||||
|
||||
def get_user_status(self, user):
|
||||
def get_user_status(self, user: IRCUser.User) -> typing.Set:
|
||||
return self.user_modes.get(user, [])
|
||||
|
|
137
src/IRCServer.py
137
src/IRCServer.py
|
@ -1,5 +1,5 @@
|
|||
import collections, socket, ssl, sys, time
|
||||
from src import IRCChannel, IRCObject, IRCUser, utils
|
||||
import collections, socket, ssl, sys, time, typing
|
||||
from src import EventManager, IRCBot, IRCChannel, IRCObject, IRCUser, utils
|
||||
|
||||
THROTTLE_LINES = 4
|
||||
THROTTLE_SECONDS = 1
|
||||
|
@ -7,8 +7,12 @@ READ_TIMEOUT_SECONDS = 120
|
|||
PING_INTERVAL_SECONDS = 30
|
||||
|
||||
class Server(IRCObject.Object):
|
||||
def __init__(self, bot, events, id, alias, hostname, port, password,
|
||||
ipv4, tls, bindhost, nickname, username, realname):
|
||||
def __init__(self,
|
||||
bot: "IRCBot.Bot",
|
||||
events: EventManager.EventHook,
|
||||
id: int, alias: str, hostname: str, port: int, password: str,
|
||||
ipv4: bool, tls: bool, bindhost: str,
|
||||
nickname: str, username: str, realname: str):
|
||||
self.connected = False
|
||||
self.bot = bot
|
||||
self.events = events
|
||||
|
@ -121,77 +125,80 @@ class Server(IRCObject.Object):
|
|||
except:
|
||||
pass
|
||||
|
||||
def set_setting(self, setting, value):
|
||||
def set_setting(self, setting: str, value: typing.Any):
|
||||
self.bot.database.server_settings.set(self.id, setting,
|
||||
value)
|
||||
def get_setting(self, setting, default=None):
|
||||
def get_setting(self, setting: str, default: typing.Any=None):
|
||||
return self.bot.database.server_settings.get(self.id,
|
||||
setting, default)
|
||||
def find_settings(self, pattern, default=[]):
|
||||
def find_settings(self, pattern: str, default: typing.Any=[]):
|
||||
return self.bot.database.server_settings.find(self.id,
|
||||
pattern, default)
|
||||
def find_settings_prefix(self, prefix, default=[]):
|
||||
def find_settings_prefix(self, prefix: str, default: typing.Any=[]):
|
||||
return self.bot.database.server_settings.find_prefix(
|
||||
self.id, prefix, default)
|
||||
def del_setting(self, setting):
|
||||
def del_setting(self, setting: str):
|
||||
self.bot.database.server_settings.delete(self.id, setting)
|
||||
|
||||
def get_user_setting(self, nickname, setting, default=None):
|
||||
def get_user_setting(self, nickname: str, setting: str,
|
||||
default: typing.Any=None):
|
||||
user_id = self.get_user_id(nickname)
|
||||
return self.bot.database.user_settings.get(user_id, setting, default)
|
||||
def set_user_setting(self, nickname, setting, value):
|
||||
def set_user_setting(self, nickname: str, setting: str, value: typing.Any):
|
||||
user_id = self.get_user_id(nickname)
|
||||
self.bot.database.user_settings.set(user_id, setting, value)
|
||||
|
||||
def get_all_user_settings(self, setting, default=[]):
|
||||
def get_all_user_settings(self, setting: str, default: typing.Any=[]):
|
||||
return self.bot.database.user_settings.find_all_by_setting(
|
||||
self.id, setting, default)
|
||||
def find_all_user_channel_settings(self, setting, default=[]):
|
||||
def find_all_user_channel_settings(self, setting: str,
|
||||
default: typing.Any=[]):
|
||||
return self.bot.database.user_channel_settings.find_all_by_setting(
|
||||
self.id, setting, default)
|
||||
|
||||
def set_own_nickname(self, nickname):
|
||||
def set_own_nickname(self, nickname: str):
|
||||
self.nickname = nickname
|
||||
self.nickname_lower = utils.irc.lower(self, nickname)
|
||||
def is_own_nickname(self, nickname):
|
||||
self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
|
||||
def is_own_nickname(self, nickname: str):
|
||||
return utils.irc.equals(self, nickname, self.nickname)
|
||||
|
||||
def add_own_mode(self, mode, arg=None):
|
||||
def add_own_mode(self, mode: str, arg: str=None):
|
||||
self.own_modes[mode] = arg
|
||||
def remove_own_mode(self, mode):
|
||||
def remove_own_mode(self, mode: str):
|
||||
del self.own_modes[mode]
|
||||
def change_own_mode(self, remove, mode, arg=None):
|
||||
def change_own_mode(self, remove: bool, mode: str, arg: str=None):
|
||||
if remove:
|
||||
self.remove_own_mode(mode)
|
||||
else:
|
||||
self.add_own_mode(mode, arg)
|
||||
|
||||
def has_user(self, nickname):
|
||||
return utils.irc.lower(self, nickname) in self.users
|
||||
def get_user(self, nickname, create=True):
|
||||
def has_user(self, nickname: str):
|
||||
return utils.irc.lower(self.case_mapping, nickname) in self.users
|
||||
def get_user(self, nickname: str, create: bool=True):
|
||||
if not self.has_user(nickname) and create:
|
||||
user_id = self.get_user_id(nickname)
|
||||
new_user = IRCUser.User(nickname, user_id, self, self.bot)
|
||||
self.events.on("new.user").call(user=new_user, server=self)
|
||||
self.users[new_user.nickname_lower] = new_user
|
||||
self.new_users.add(new_user)
|
||||
return self.users.get(utils.irc.lower(self, nickname), None)
|
||||
def get_user_id(self, nickname):
|
||||
return self.users.get(utils.irc.lower(self.case_mapping, nickname),
|
||||
None)
|
||||
def get_user_id(self, nickname: str):
|
||||
self.bot.database.users.add(self.id, nickname)
|
||||
return self.bot.database.users.get_id(self.id, nickname)
|
||||
def remove_user(self, user):
|
||||
def remove_user(self, user: IRCUser.User):
|
||||
del self.users[user.nickname_lower]
|
||||
for channel in user.channels:
|
||||
channel.remove_user(user)
|
||||
|
||||
def change_user_nickname(self, old_nickname, new_nickname):
|
||||
user = self.users.pop(utils.irc.lower(self, old_nickname))
|
||||
def change_user_nickname(self, old_nickname: str, new_nickname: str):
|
||||
user = self.users.pop(utils.irc.lower(self.case_mapping, old_nickname))
|
||||
user._id = self.get_user_id(new_nickname)
|
||||
self.users[utils.irc.lower(self, new_nickname)] = user
|
||||
def has_channel(self, channel_name):
|
||||
self.users[utils.irc.lower(self.case_mapping, new_nickname)] = user
|
||||
def has_channel(self, channel_name: str):
|
||||
return channel_name[0] in self.channel_types and utils.irc.lower(
|
||||
self, channel_name) in self.channels
|
||||
def get_channel(self, channel_name):
|
||||
self.case_mapping, channel_name) in self.channels
|
||||
def get_channel(self, channel_name: str):
|
||||
if not self.has_channel(channel_name):
|
||||
channel_id = self.get_channel_id(channel_name)
|
||||
new_channel = IRCChannel.Channel(channel_name, channel_id,
|
||||
|
@ -199,15 +206,15 @@ class Server(IRCObject.Object):
|
|||
self.events.on("new.channel").call(channel=new_channel,
|
||||
server=self)
|
||||
self.channels[new_channel.name] = new_channel
|
||||
return self.channels[utils.irc.lower(self, channel_name)]
|
||||
def get_channel_id(self, channel_name):
|
||||
return self.channels[utils.irc.lower(self.case_mapping, channel_name)]
|
||||
def get_channel_id(self, channel_name: str):
|
||||
self.bot.database.channels.add(self.id, channel_name)
|
||||
return self.bot.database.channels.get_id(self.id, channel_name)
|
||||
def remove_channel(self, channel):
|
||||
def remove_channel(self, channel: IRCChannel.Channel):
|
||||
for user in channel.users:
|
||||
user.part_channel(channel)
|
||||
del self.channels[channel.name]
|
||||
def parse_data(self, line):
|
||||
def parse_data(self, line: str):
|
||||
if not line:
|
||||
return
|
||||
self.events.on("raw").call_unsafe(server=self, line=line)
|
||||
|
@ -271,7 +278,7 @@ class Server(IRCObject.Object):
|
|||
def read_timed_out(self):
|
||||
return self.until_read_timeout == 0
|
||||
|
||||
def send(self, data):
|
||||
def send(self, data: str):
|
||||
returned = self.events.on("preprocess.send").call_unsafe_for_result(
|
||||
server=self, line=data)
|
||||
line = returned or data
|
||||
|
@ -314,16 +321,16 @@ class Server(IRCObject.Object):
|
|||
time_left = time_left-now
|
||||
return time_left
|
||||
|
||||
def send_user(self, username, realname):
|
||||
def send_user(self, username: str, realname: str):
|
||||
self.send("USER %s 0 * :%s" % (username, realname))
|
||||
def send_nick(self, nickname):
|
||||
def send_nick(self, nickname: str):
|
||||
self.send("NICK %s" % nickname)
|
||||
|
||||
def send_capibility_ls(self):
|
||||
self.send("CAP LS 302")
|
||||
def queue_capability(self, capability):
|
||||
def queue_capability(self, capability: str):
|
||||
self._capability_queue.add(capability)
|
||||
def queue_capabilities(self, capabilities):
|
||||
def queue_capabilities(self, capabilities: typing.List[str]):
|
||||
self._capability_queue.update(capabilities)
|
||||
def send_capability_queue(self):
|
||||
if self.has_capability_queue():
|
||||
|
@ -332,46 +339,46 @@ class Server(IRCObject.Object):
|
|||
self.send_capability_request(capabilities)
|
||||
def has_capability_queue(self):
|
||||
return bool(len(self._capability_queue))
|
||||
def send_capability_request(self, capability):
|
||||
def send_capability_request(self, capability: str):
|
||||
self.send("CAP REQ :%s" % capability)
|
||||
def send_capability_end(self):
|
||||
self.send("CAP END")
|
||||
def send_authenticate(self, text):
|
||||
def send_authenticate(self, text: str):
|
||||
self.send("AUTHENTICATE %s" % text)
|
||||
def send_starttls(self):
|
||||
self.send("STARTTLS")
|
||||
|
||||
def waiting_for_capabilities(self):
|
||||
return bool(len(self._capabilities_waiting))
|
||||
def wait_for_capability(self, capability):
|
||||
def wait_for_capability(self, capability: str):
|
||||
self._capabilities_waiting.add(capability)
|
||||
def capability_done(self, capability):
|
||||
def capability_done(self, capability: str):
|
||||
self._capabilities_waiting.remove(capability)
|
||||
if not self._capabilities_waiting:
|
||||
self.send_capability_end()
|
||||
|
||||
def send_pass(self, password):
|
||||
def send_pass(self, password: str):
|
||||
self.send("PASS %s" % password)
|
||||
|
||||
def send_ping(self, nonce="hello"):
|
||||
def send_ping(self, nonce: str="hello"):
|
||||
self.send("PING :%s" % nonce)
|
||||
def send_pong(self, nonce="hello"):
|
||||
def send_pong(self, nonce: str="hello"):
|
||||
self.send("PONG :%s" % nonce)
|
||||
|
||||
def try_rejoin(self, event):
|
||||
def try_rejoin(self, event: EventManager.Event):
|
||||
if event["server_id"] == self.id and event["channel_name"
|
||||
] in self.attempted_join:
|
||||
self.send_join(event["channel_name"], event["key"])
|
||||
def send_join(self, channel_name, key=None):
|
||||
def send_join(self, channel_name: str, key: str=None):
|
||||
self.send("JOIN %s%s" % (channel_name,
|
||||
"" if key == None else " %s" % key))
|
||||
def send_part(self, channel_name, reason=None):
|
||||
def send_part(self, channel_name: str, reason: str=None):
|
||||
self.send("PART %s%s" % (channel_name,
|
||||
"" if reason == None else " %s" % reason))
|
||||
def send_quit(self, reason="Leaving"):
|
||||
def send_quit(self, reason: str="Leaving"):
|
||||
self.send("QUIT :%s" % reason)
|
||||
|
||||
def _tag_str(self, tags):
|
||||
def _tag_str(self, tags: dict):
|
||||
tag_str = ""
|
||||
for tag, value in tags.items():
|
||||
if tag_str:
|
||||
|
@ -383,7 +390,8 @@ class Server(IRCObject.Object):
|
|||
tag_str = "@%s " % tag_str
|
||||
return tag_str
|
||||
|
||||
def send_message(self, target, message, prefix=None, tags={}):
|
||||
def send_message(self, target: str, message: str, prefix: str=None,
|
||||
tags: dict={}):
|
||||
full_message = message if not prefix else prefix+message
|
||||
self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target,
|
||||
full_message))
|
||||
|
@ -408,7 +416,8 @@ class Server(IRCObject.Object):
|
|||
message=full_message, message_split=full_message_split,
|
||||
user=user, action=action, server=self)
|
||||
|
||||
def send_notice(self, target, message, prefix=None, tags={}):
|
||||
def send_notice(self, target: str, message: str, prefix: str=None,
|
||||
tags: dict={}):
|
||||
full_message = message if not prefix else prefix+message
|
||||
self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target,
|
||||
full_message))
|
||||
|
@ -419,31 +428,31 @@ class Server(IRCObject.Object):
|
|||
self.get_user(target).buffer.add_notice(None, message, tags,
|
||||
True)
|
||||
|
||||
def send_mode(self, target, mode=None, args=None):
|
||||
def send_mode(self, target: str, mode: str=None, args: str=None):
|
||||
self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode,
|
||||
"" if args == None else " %s" % args))
|
||||
|
||||
def send_topic(self, channel_name, topic):
|
||||
def send_topic(self, channel_name: str, topic: str):
|
||||
self.send("TOPIC %s :%s" % (channel_name, topic))
|
||||
def send_kick(self, channel_name, target, reason=None):
|
||||
def send_kick(self, channel_name: str, target: str, reason: str=None):
|
||||
self.send("KICK %s %s%s" % (channel_name, target,
|
||||
"" if reason == None else " :%s" % reason))
|
||||
def send_names(self, channel_name):
|
||||
def send_names(self, channel_name: str):
|
||||
self.send("NAMES %s" % channel_name)
|
||||
def send_list(self, search_for=None):
|
||||
def send_list(self, search_for: str=None):
|
||||
self.send(
|
||||
"LIST%s" % "" if search_for == None else " %s" % search_for)
|
||||
def send_invite(self, target, channel_name):
|
||||
def send_invite(self, target: str, channel_name: str):
|
||||
self.send("INVITE %s %s" % (target, channel_name))
|
||||
|
||||
def send_whois(self, target):
|
||||
def send_whois(self, target: str):
|
||||
self.send("WHOIS %s" % target)
|
||||
def send_whowas(self, target, amount=None, server=None):
|
||||
def send_whowas(self, target: str, amount: int=None, server: str=None):
|
||||
self.send("WHOWAS %s%s%s" % (target,
|
||||
"" if amount == None else " %s" % amount,
|
||||
"" if server == None else " :%s" % server))
|
||||
def send_who(self, filter=None):
|
||||
def send_who(self, filter: str=None):
|
||||
self.send("WHO%s" % ("" if filter == None else " %s" % filter))
|
||||
def send_whox(self, mask, filter, fields, label=None):
|
||||
def send_whox(self, mask: str, filter: str, fields: str, label: str=None):
|
||||
self.send("WHO %s %s%%%s%s" % (mask, filter, fields,
|
||||
","+label if label else ""))
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import uuid
|
||||
from src import IRCBuffer, IRCObject, utils
|
||||
import typing, uuid
|
||||
from src import IRCBot, IRCChannel, IRCBuffer, IRCObject, IRCServer, utils
|
||||
|
||||
class User(IRCObject.Object):
|
||||
def __init__(self, nickname, id, server, bot):
|
||||
def __init__(self, nickname: str, id: int, server: "IRCServer.Server",
|
||||
bot: "IRCBot.Bot"):
|
||||
self.server = server
|
||||
self.set_nickname(nickname)
|
||||
self._id = id
|
||||
|
@ -20,46 +21,51 @@ class User(IRCObject.Object):
|
|||
self.away = False
|
||||
self.buffer = IRCBuffer.Buffer(bot, server)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "IRCUser.User(%s|%s)" % (self.server.name, self.name)
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.nickname
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self)-> int:
|
||||
return (self.identified_account_id_override or
|
||||
self.identified_account_id or self._id)
|
||||
def get_identified_account(self):
|
||||
def get_identified_account(self) -> str:
|
||||
return (self.identified_account_override or self.identified_account)
|
||||
|
||||
def set_nickname(self, nickname):
|
||||
def set_nickname(self, nickname: str):
|
||||
self.nickname = nickname
|
||||
self.nickname_lower = utils.irc.lower(self.server, nickname)
|
||||
self.nickname_lower = utils.irc.lower(self.server.case_mapping,
|
||||
nickname)
|
||||
self.name = self.nickname_lower
|
||||
def join_channel(self, channel):
|
||||
def join_channel(self, channel: "IRCChannel.Channel"):
|
||||
self.channels.add(channel)
|
||||
def part_channel(self, channel):
|
||||
def part_channel(self, channel: "IRCChannel.Channel"):
|
||||
self.channels.remove(channel)
|
||||
def set_setting(self, setting, value):
|
||||
|
||||
def set_setting(self, setting: str, value: typing.Any):
|
||||
self.bot.database.user_settings.set(self.get_id(), setting, value)
|
||||
def get_setting(self, setting, default=None):
|
||||
def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
|
||||
return self.bot.database.user_settings.get(self.get_id(), setting,
|
||||
default)
|
||||
def find_settings(self, pattern, default=[]):
|
||||
def find_settings(self, pattern: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_settings.find(self.get_id(), pattern,
|
||||
default)
|
||||
def find_settings_prefix(self, prefix, default=[]):
|
||||
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
|
||||
) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_settings.find_prefix(self.get_id(),
|
||||
prefix, default)
|
||||
def del_setting(self, setting):
|
||||
self.bot.database.user_settings.delete(self.get_id(), setting)
|
||||
def get_channel_settings_per_setting(self, setting, default=[]):
|
||||
def get_channel_settings_per_setting(self, setting: str,
|
||||
default: typing.Any=[]) -> typing.List[typing.Any]:
|
||||
return self.bot.database.user_channel_settings.find_by_setting(
|
||||
self.get_id(), setting, default)
|
||||
|
||||
def send_message(self, message, prefix=None, tags={}):
|
||||
def send_message(self, message: str, prefix: str=None, tags: dict={}):
|
||||
self.server.send_message(self.nickname, message, prefix=prefix,
|
||||
tags=tags)
|
||||
def send_notice(self, text, prefix=None, tags={}):
|
||||
def send_notice(self, text: str, prefix: str=None, tags: dict={}):
|
||||
self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags)
|
||||
def send_ctcp_response(self, command, args):
|
||||
def send_ctcp_response(self, command: str, args: str):
|
||||
self.send_notice("\x01%s %s\x01" % (command, args))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import logging, logging.handlers, os, sys, time
|
||||
import logging, logging.handlers, os, sys, time, typing
|
||||
|
||||
LEVELS = {
|
||||
"trace": logging.DEBUG-1,
|
||||
|
@ -23,7 +23,7 @@ class BitBotFormatter(logging.Formatter):
|
|||
return s
|
||||
|
||||
class Log(object):
|
||||
def __init__(self, level, location):
|
||||
def __init__(self, level: str, location: str):
|
||||
logging.addLevelName(LEVELS["trace"], "TRACE")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -49,17 +49,17 @@ class Log(object):
|
|||
file_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
def trace(self, message, params, **kwargs):
|
||||
def trace(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, LEVELS["trace"], kwargs)
|
||||
def debug(self, message, params, **kwargs):
|
||||
def debug(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, logging.DEBUG, kwargs)
|
||||
def info(self, message, params, **kwargs):
|
||||
def info(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, logging.INFO, kwargs)
|
||||
def warn(self, message, params, **kwargs):
|
||||
def warn(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, logging.WARN, kwargs)
|
||||
def error(self, message, params, **kwargs):
|
||||
def error(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, logging.ERROR, kwargs)
|
||||
def critical(self, message, params, **kwargs):
|
||||
def critical(self, message: str, params: typing.List, **kwargs):
|
||||
self._log(message, params, logging.CRITICAL, kwargs)
|
||||
def _log(self, message, params, level, kwargs):
|
||||
def _log(self, message: str, params: typing.List, level: int, kwargs: dict):
|
||||
self.logger.log(level, message, *params, **kwargs)
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import gc, glob, imp, io, inspect, os, sys, uuid
|
||||
from . import utils
|
||||
|
||||
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
|
||||
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
|
||||
import gc, glob, imp, io, inspect, os, sys, typing, uuid
|
||||
from src import Config, EventManager, Exports, IRCBot, Logging, Timers, utils
|
||||
|
||||
class ModuleException(Exception):
|
||||
pass
|
||||
|
@ -22,7 +19,11 @@ class ModuleNotLoadedWarning(ModuleWarning):
|
|||
pass
|
||||
|
||||
class BaseModule(object):
|
||||
def __init__(self, bot, events, exports, timers):
|
||||
def __init__(self,
|
||||
bot: "IRCBot.Bot",
|
||||
events: EventManager.EventHook,
|
||||
exports: Exports.Exports,
|
||||
timers: Timers.Timers):
|
||||
self.bot = bot
|
||||
self.events = events
|
||||
self.exports = exports
|
||||
|
@ -32,7 +33,13 @@ class BaseModule(object):
|
|||
pass
|
||||
|
||||
class ModuleManager(object):
|
||||
def __init__(self, events, exports, timers, config, log, directory):
|
||||
def __init__(self,
|
||||
events: EventManager.EventHook,
|
||||
exports: Exports.Exports,
|
||||
timers: Timers.Timers,
|
||||
config: Config.Config,
|
||||
log: Logging.Log,
|
||||
directory: str):
|
||||
self.events = events
|
||||
self.exports = exports
|
||||
self.config = config
|
||||
|
@ -43,23 +50,24 @@ class ModuleManager(object):
|
|||
self.modules = {}
|
||||
self.waiting_requirement = {}
|
||||
|
||||
def list_modules(self):
|
||||
def list_modules(self) -> typing.List[str]:
|
||||
return sorted(glob.glob(os.path.join(self.directory, "*.py")))
|
||||
|
||||
def _module_name(self, path):
|
||||
def _module_name(self, path: str) -> str:
|
||||
return os.path.basename(path).rsplit(".py", 1)[0].lower()
|
||||
def _module_path(self, name):
|
||||
def _module_path(self, name: str) -> str:
|
||||
return os.path.join(self.directory, "%s.py" % name)
|
||||
def _import_name(self, name):
|
||||
def _import_name(self, name: str) -> str:
|
||||
return "bitbot_%s" % name
|
||||
|
||||
def _get_magic(self, obj, magic, default):
|
||||
def _get_magic(self, obj: typing.Any, magic: str, default: typing.Any
|
||||
) -> typing.Any:
|
||||
return getattr(obj, magic) if hasattr(obj, magic) else default
|
||||
|
||||
def _load_module(self, bot, name):
|
||||
def _load_module(self, bot: "IRCBot.Bot", name: str):
|
||||
path = self._module_path(name)
|
||||
|
||||
for hashflag, value in utils.get_hashflags(path):
|
||||
for hashflag, value in utils.parse.hashflags(path):
|
||||
if hashflag == "ignore":
|
||||
# nope, ignore this module.
|
||||
raise ModuleNotLoadedWarning("module ignored")
|
||||
|
@ -97,10 +105,12 @@ class ModuleManager(object):
|
|||
module_object._name = name.title()
|
||||
for attribute_name in dir(module_object):
|
||||
attribute = getattr(module_object, attribute_name)
|
||||
for hook in self._get_magic(attribute, BITBOT_HOOKS_MAGIC, []):
|
||||
for hook in self._get_magic(attribute,
|
||||
utils.consts.BITBOT_HOOKS_MAGIC, []):
|
||||
context_events.on(hook["event"]).hook(attribute,
|
||||
**hook["kwargs"])
|
||||
for export in self._get_magic(module_object, BITBOT_EXPORTS_MAGIC, []):
|
||||
for export in self._get_magic(module_object,
|
||||
utils.consts.BITBOT_EXPORTS_MAGIC, []):
|
||||
context_exports.add(export["setting"], export["value"])
|
||||
|
||||
module_object._context = context
|
||||
|
@ -111,7 +121,7 @@ class ModuleManager(object):
|
|||
"attempted to be used twice")
|
||||
return module_object
|
||||
|
||||
def load_module(self, bot, name):
|
||||
def load_module(self, bot: "IRCBot.Bot", name: str):
|
||||
try:
|
||||
module = self._load_module(bot, name)
|
||||
except ModuleWarning as warning:
|
||||
|
@ -128,7 +138,8 @@ class ModuleManager(object):
|
|||
self.load_module(bot, requirement_name)
|
||||
self.log.info("Module '%s' loaded", [name])
|
||||
|
||||
def load_modules(self, bot, whitelist=[], blacklist=[]):
|
||||
def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
|
||||
blacklist: typing.List[str]=[]):
|
||||
for path in self.list_modules():
|
||||
name = self._module_name(path)
|
||||
if name in whitelist or (not whitelist and not name in blacklist):
|
||||
|
@ -137,7 +148,7 @@ class ModuleManager(object):
|
|||
except ModuleWarning:
|
||||
pass
|
||||
|
||||
def unload_module(self, name):
|
||||
def unload_module(self, name: str):
|
||||
if not name in self.modules:
|
||||
raise ModuleNotFoundException()
|
||||
module = self.modules[name]
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
|
||||
import socket, typing
|
||||
|
||||
class Socket(object):
|
||||
def __init__(self, socket, on_read, encoding="utf8"):
|
||||
def __init__(self, socket: socket.socket,
|
||||
on_read: typing.Callable[["Socket", str], None],
|
||||
encoding: str="utf8"):
|
||||
self.socket = socket
|
||||
self._on_read = on_read
|
||||
self.encoding = encoding
|
||||
|
@ -12,18 +14,18 @@ class Socket(object):
|
|||
self.length = None
|
||||
self.connected = True
|
||||
|
||||
def fileno(self):
|
||||
def fileno(self) -> int:
|
||||
return self.socket.fileno()
|
||||
|
||||
def disconnect(self):
|
||||
self.connected = False
|
||||
|
||||
def _decode(self, s):
|
||||
return s.decode(self.encoding) if self.encoding else s
|
||||
def _encode(self, s):
|
||||
return s.encode(self.encoding) if self.encoding else s
|
||||
def _decode(self, s: bytes) -> str:
|
||||
return s.decode(self.encoding)
|
||||
def _encode(self, s: str) -> bytes:
|
||||
return s.encode(self.encoding)
|
||||
|
||||
def read(self):
|
||||
def read(self) -> typing.Optional[typing.List[str]]:
|
||||
data = self.socket.recv(1024)
|
||||
if not data:
|
||||
return None
|
||||
|
@ -35,17 +37,17 @@ class Socket(object):
|
|||
if data_split[-1]:
|
||||
self._read_buffer = data_split.pop(-1)
|
||||
return [self._decode(data) for data in data_split]
|
||||
return [data.decode(self.encoding)]
|
||||
return [self._decode(data)]
|
||||
|
||||
def parse_data(self, data):
|
||||
def parse_data(self, data: str):
|
||||
self._on_read(self, data)
|
||||
|
||||
def send(self, data):
|
||||
def send(self, data: str):
|
||||
self._write_buffer += self._encode(data)
|
||||
|
||||
def _send(self):
|
||||
self._write_buffer = self._write_buffer[self.socket.send(
|
||||
self._write_buffer):]
|
||||
|
||||
def waiting_send(self):
|
||||
def waiting_send(self) -> bool:
|
||||
return bool(len(self._write_buffer))
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import time, uuid
|
||||
import time, typing, uuid
|
||||
from src import Database, EventManager, Logging
|
||||
|
||||
class Timer(object):
|
||||
def __init__(self, id, context, name, delay, next_due, kwargs):
|
||||
def __init__(self, id: int, context: str, name: str, delay: float,
|
||||
next_due: float, kwargs: dict):
|
||||
self.id = id
|
||||
self.context = context
|
||||
self.name = name
|
||||
|
@ -15,9 +17,9 @@ class Timer(object):
|
|||
|
||||
def set_next_due(self):
|
||||
self.next_due = time.time()+self.delay
|
||||
def due(self):
|
||||
def due(self) -> bool:
|
||||
return self.time_left() <= 0
|
||||
def time_left(self):
|
||||
def time_left(self) -> float:
|
||||
return self.next_due-time.time()
|
||||
|
||||
def redo(self):
|
||||
|
@ -25,42 +27,33 @@ class Timer(object):
|
|||
self.set_next_due()
|
||||
def finish(self):
|
||||
self._done = True
|
||||
def done(self):
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
class TimersContext(object):
|
||||
def __init__(self, parent, context):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
def add(self, name, delay, next_due=None, **kwargs):
|
||||
self._parent._add(self.context, name, delay, next_due, None, False,
|
||||
kwargs)
|
||||
def add_persistent(self, name, delay, next_due=None, **kwargs):
|
||||
self._parent._add(None, name, delay, next_due, None, True,
|
||||
kwargs)
|
||||
|
||||
class Timers(object):
|
||||
def __init__(self, database, events, log):
|
||||
def __init__(self, database: Database.Database,
|
||||
events: EventManager.EventHook,
|
||||
log: Logging.Log):
|
||||
self.database = database
|
||||
self.events = events
|
||||
self.log = log
|
||||
self.timers = []
|
||||
self.context_timers = {}
|
||||
|
||||
def new_context(self, context):
|
||||
def new_context(self, context: str) -> "TimersContext":
|
||||
return TimersContext(self, context)
|
||||
|
||||
def setup(self, timers):
|
||||
def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
|
||||
for name, timer in timers:
|
||||
id = name.split("timer-", 1)[1]
|
||||
self._add(timer["name"], None, timer["delay"], timer[
|
||||
"next-due"], id, False, timer["kwargs"])
|
||||
|
||||
def _persist(self, timer):
|
||||
def _persist(self, timer: Timer):
|
||||
self.database.bot_settings.set("timer-%s" % timer.id, {
|
||||
"name": timer.name, "delay": timer.delay,
|
||||
"next-due": timer.next_due, "kwargs": timer.kwargs})
|
||||
def _remove(self, timer):
|
||||
def _remove(self, timer: Timer):
|
||||
if timer.context:
|
||||
self.context_timers[timer.context].remove(timer)
|
||||
if not self.context_timers[timer.context]:
|
||||
|
@ -69,11 +62,13 @@ class Timers(object):
|
|||
self.timers.remove(timer)
|
||||
self.database.bot_settings.delete("timer-%s" % timer.id)
|
||||
|
||||
def add(self, name, delay, next_due=None, **kwargs):
|
||||
def add(self, name: str, delay: float, next_due: float=None, **kwargs):
|
||||
self._add(None, name, delay, next_due, None, False, kwargs)
|
||||
def add_persistent(self, name, delay, next_due=None, **kwargs):
|
||||
def add_persistent(self, name: str, delay: float, next_due: float=None,
|
||||
**kwargs):
|
||||
self._add(None, name, delay, next_due, None, True, kwargs)
|
||||
def _add(self, context, name, delay, next_due, id, persist, kwargs):
|
||||
def _add(self, context: str, name: str, delay: float, next_due: float,
|
||||
id: str, persist: bool, kwargs: dict):
|
||||
id = id or uuid.uuid4().hex
|
||||
timer = Timer(id, context, name, delay, next_due, kwargs)
|
||||
if persist:
|
||||
|
@ -86,13 +81,13 @@ class Timers(object):
|
|||
else:
|
||||
self.timers.append(timer)
|
||||
|
||||
def next(self):
|
||||
def next(self) -> float:
|
||||
times = filter(None, [timer.time_left() for timer in self.get_timers()])
|
||||
if not times:
|
||||
return None
|
||||
return max(min(times), 0)
|
||||
|
||||
def get_timers(self):
|
||||
def get_timers(self) -> typing.List[Timer]:
|
||||
return self.timers + sum(self.context_timers.values(), [])
|
||||
|
||||
def call(self):
|
||||
|
@ -104,6 +99,19 @@ class Timers(object):
|
|||
if timer.done():
|
||||
self._remove(timer)
|
||||
|
||||
def purge_context(self, context):
|
||||
def purge_context(self, context: str):
|
||||
if context in self.context_timers:
|
||||
del self.context_timers[context]
|
||||
|
||||
class TimersContext(object):
|
||||
def __init__(self, parent: Timers, context: str):
|
||||
self._parent = parent
|
||||
self.context = context
|
||||
def add(self, name: str, delay: float, next_due: float=None,
|
||||
**kwargs):
|
||||
self._parent._add(self.context, name, delay, next_due, None, False,
|
||||
kwargs)
|
||||
def add_persistent(self, name: str, delay: float, next_due: float=None,
|
||||
**kwargs):
|
||||
self._parent._add(None, name, delay, next_due, None, True,
|
||||
kwargs)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import decimal, io, re
|
||||
from src import ModuleManager
|
||||
from . import irc, http
|
||||
import decimal, io, re, typing
|
||||
from src.utils import consts, irc, http, parse
|
||||
|
||||
TIME_SECOND = 1
|
||||
TIME_MINUTE = TIME_SECOND*60
|
||||
|
@ -8,7 +7,7 @@ TIME_HOUR = TIME_MINUTE*60
|
|||
TIME_DAY = TIME_HOUR*24
|
||||
TIME_WEEK = TIME_DAY*7
|
||||
|
||||
def time_unit(seconds):
|
||||
def time_unit(seconds: int) -> typing.Tuple[int, str]:
|
||||
since = None
|
||||
unit = None
|
||||
if seconds >= TIME_WEEK:
|
||||
|
@ -29,7 +28,7 @@ def time_unit(seconds):
|
|||
since = int(since)
|
||||
if since > 1:
|
||||
unit = "%ss" % unit # pluralise the unit
|
||||
return [since, unit]
|
||||
return (since, unit)
|
||||
|
||||
REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I)
|
||||
|
||||
|
@ -38,7 +37,7 @@ SECONDS_HOURS = SECONDS_MINUTES*60
|
|||
SECONDS_DAYS = SECONDS_HOURS*24
|
||||
SECONDS_WEEKS = SECONDS_DAYS*7
|
||||
|
||||
def from_pretty_time(pretty_time):
|
||||
def from_pretty_time(pretty_time: str) -> typing.Optional[int]:
|
||||
seconds = 0
|
||||
for match in re.findall(REGEX_PRETTYTIME, pretty_time):
|
||||
number, unit = int(match[:-1]), match[-1].lower()
|
||||
|
@ -54,12 +53,14 @@ def from_pretty_time(pretty_time):
|
|||
if seconds > 0:
|
||||
return seconds
|
||||
|
||||
UNIT_MINIMUM = 6
|
||||
UNIT_SECOND = 5
|
||||
UNIT_MINUTE = 4
|
||||
UNIT_HOUR = 3
|
||||
UNIT_DAY = 2
|
||||
UNIT_WEEK = 1
|
||||
def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
|
||||
def to_pretty_time(total_seconds: int, minimum_unit: int=UNIT_SECOND,
|
||||
max_units: int=UNIT_MINIMUM) -> str:
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
days, hours = divmod(hours, 24)
|
||||
|
@ -84,7 +85,7 @@ def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
|
|||
units += 1
|
||||
return out
|
||||
|
||||
def parse_number(s):
|
||||
def parse_number(s: str) -> str:
|
||||
try:
|
||||
decimal.Decimal(s)
|
||||
return s
|
||||
|
@ -110,28 +111,18 @@ def parse_number(s):
|
|||
|
||||
IS_TRUE = ["true", "yes", "on", "y"]
|
||||
IS_FALSE = ["false", "no", "off", "n"]
|
||||
def bool_or_none(s):
|
||||
def bool_or_none(s: str) -> typing.Optional[bool]:
|
||||
s = s.lower()
|
||||
if s in IS_TRUE:
|
||||
return True
|
||||
elif s in IS_FALSE:
|
||||
return False
|
||||
def int_or_none(s):
|
||||
def int_or_none(s: str) -> typing.Optional[int]:
|
||||
stripped_s = s.lstrip("0")
|
||||
if stripped_s.isdigit():
|
||||
return int(stripped_s)
|
||||
|
||||
def get_closest_setting(event, setting, default=None):
|
||||
server = event["server"]
|
||||
if "channel" in event:
|
||||
closest = event["channel"]
|
||||
elif "target" in event and "is_channel" in event and event["is_channel"]:
|
||||
closest = event["target"]
|
||||
else:
|
||||
closest = event["user"]
|
||||
return closest.get_setting(setting, server.get_setting(setting, default))
|
||||
|
||||
def prevent_highlight(nickname):
|
||||
def prevent_highlight(nickname: str) -> str:
|
||||
return nickname[0]+"\u200c"+nickname[1:]
|
||||
|
||||
class EventError(Exception):
|
||||
|
@ -139,80 +130,34 @@ class EventError(Exception):
|
|||
class EventsResultsError(EventError):
|
||||
def __init__(self):
|
||||
EventError.__init__(self, "Failed to load results")
|
||||
class EventsNotEnoughArgsError(EventError):
|
||||
def __init__(self, n):
|
||||
EventError.__init__(self, "Not enough arguments (minimum %d)" % n)
|
||||
class EventsUsageError(EventError):
|
||||
def __init__(self, usage):
|
||||
EventError.__init__(self, "Not enough arguments, usage: %s" % usage)
|
||||
|
||||
def _set_get_append(obj, setting, item):
|
||||
def _set_get_append(obj: typing.Any, setting: str, item: typing.Any):
|
||||
if not hasattr(obj, setting):
|
||||
setattr(obj, setting, [])
|
||||
getattr(obj, setting).append(item)
|
||||
def hook(event, **kwargs):
|
||||
def hook(event: str, **kwargs):
|
||||
def _hook_func(func):
|
||||
_set_get_append(func, ModuleManager.BITBOT_HOOKS_MAGIC,
|
||||
_set_get_append(func, consts.BITBOT_HOOKS_MAGIC,
|
||||
{"event": event, "kwargs": kwargs})
|
||||
return func
|
||||
return _hook_func
|
||||
def export(setting, value):
|
||||
def export(setting: str, value: typing.Any):
|
||||
def _export_func(module):
|
||||
_set_get_append(module, ModuleManager.BITBOT_EXPORTS_MAGIC,
|
||||
_set_get_append(module, consts.BITBOT_EXPORTS_MAGIC,
|
||||
{"setting": setting, "value": value})
|
||||
return module
|
||||
return _export_func
|
||||
|
||||
COMMENT_TYPES = ["#", "//"]
|
||||
def get_hashflags(filename):
|
||||
hashflags = {}
|
||||
with io.open(filename, mode="r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
line = line.strip("\n")
|
||||
found = False
|
||||
for comment_type in COMMENT_TYPES:
|
||||
if line.startswith(comment_type):
|
||||
line = line.replace(comment_type, "", 1).lstrip()
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
break
|
||||
elif line.startswith("--"):
|
||||
hashflag, sep, value = line[2:].partition(" ")
|
||||
hashflags[hashflag] = value if sep else None
|
||||
return hashflags.items()
|
||||
|
||||
class Docstring(object):
|
||||
def __init__(self, description, items, var_items):
|
||||
self.description = description
|
||||
self.items = items
|
||||
self.var_items = var_items
|
||||
|
||||
def parse_docstring(s):
|
||||
description = ""
|
||||
last_item = None
|
||||
items = {}
|
||||
var_items = {}
|
||||
if s:
|
||||
for line in s.split("\n"):
|
||||
line = line.strip()
|
||||
|
||||
if line:
|
||||
if line[0] == ":":
|
||||
key, _, value = line[1:].partition(": ")
|
||||
last_item = key
|
||||
|
||||
if key in var_items:
|
||||
var_items[key].append(value)
|
||||
elif key in items:
|
||||
var_items[key] = [items.pop(key), value]
|
||||
else:
|
||||
items[key] = value
|
||||
else:
|
||||
if last_item:
|
||||
items[last_item] += " %s" % line
|
||||
else:
|
||||
if description:
|
||||
description += " "
|
||||
description += line
|
||||
return Docstring(description, items, var_items)
|
||||
|
||||
def top_10(items, convert_key=lambda x: x, value_format=lambda x: x):
|
||||
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
|
||||
def top_10(items: typing.List[typing.Any],
|
||||
convert_key: TOP_10_CALLABLE=lambda x: x,
|
||||
value_format: TOP_10_CALLABLE=lambda x: x):
|
||||
top_10 = sorted(items.keys())
|
||||
top_10 = sorted(top_10, key=items.get, reverse=True)[:10]
|
||||
|
||||
|
|
2
src/utils/consts.py
Normal file
2
src/utils/consts.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
|
||||
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
|
|
@ -1,4 +1,4 @@
|
|||
import re, signal, traceback, urllib.error, urllib.parse
|
||||
import re, signal, traceback, typing, urllib.error, urllib.parse
|
||||
import json as _json
|
||||
import bs4, requests
|
||||
|
||||
|
@ -18,9 +18,10 @@ class HTTPParsingException(HTTPException):
|
|||
def throw_timeout():
|
||||
raise HTTPTimeoutException()
|
||||
|
||||
def get_url(url, method="GET", get_params={}, post_data=None, headers={},
|
||||
json_data=None, code=False, json=False, soup=False, parser="lxml",
|
||||
fallback_encoding="utf8"):
|
||||
def get_url(url: str, method: str="GET", get_params: dict={},
|
||||
post_data: typing.Any=None, headers: dict={},
|
||||
json_data: typing.Any=None, code: bool=False, json: bool=False,
|
||||
soup: bool=False, parser: str="lxml", fallback_encoding: str="utf8"):
|
||||
|
||||
if not urllib.parse.urlparse(url).scheme:
|
||||
url = "http://%s" % url
|
||||
|
@ -66,6 +67,6 @@ def get_url(url, method="GET", get_params={}, post_data=None, headers={},
|
|||
else:
|
||||
return data
|
||||
|
||||
def strip_html(s):
|
||||
def strip_html(s: str) -> str:
|
||||
return bs4.BeautifulSoup(s, "lxml").get_text()
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import string, re
|
||||
import string, re, typing
|
||||
|
||||
ASCII_UPPER = string.ascii_uppercase
|
||||
ASCII_LOWER = string.ascii_lowercase
|
||||
|
@ -7,32 +7,36 @@ STRICT_RFC1459_LOWER = ASCII_LOWER+r'|{}'
|
|||
RFC1459_UPPER = STRICT_RFC1459_UPPER+"^"
|
||||
RFC1459_LOWER = STRICT_RFC1459_LOWER+"~"
|
||||
|
||||
def remove_colon(s):
|
||||
def remove_colon(s: str) -> str:
|
||||
if s.startswith(":"):
|
||||
s = s[1:]
|
||||
return s
|
||||
|
||||
MULTI_REPLACE_ITERABLE = typing.Iterable[str]
|
||||
# case mapping lowercase/uppcase logic
|
||||
def _multi_replace(s, chars1, chars2):
|
||||
def _multi_replace(s: str,
|
||||
chars1: typing.Iterable[str],
|
||||
chars2: typing.Iterable[str]) -> str:
|
||||
for char1, char2 in zip(chars1, chars2):
|
||||
s = s.replace(char1, char2)
|
||||
return s
|
||||
def lower(server, s):
|
||||
if server.case_mapping == "ascii":
|
||||
def lower(case_mapping: str, s: str) -> str:
|
||||
if case_mapping == "ascii":
|
||||
return _multi_replace(s, ASCII_UPPER, ASCII_LOWER)
|
||||
elif server.case_mapping == "rfc1459":
|
||||
elif case_mapping == "rfc1459":
|
||||
return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER)
|
||||
elif server.case_mapping == "strict-rfc1459":
|
||||
elif case_mapping == "strict-rfc1459":
|
||||
return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER)
|
||||
else:
|
||||
raise ValueError("unknown casemapping '%s'" % server.case_mapping)
|
||||
raise ValueError("unknown casemapping '%s'" % case_mapping)
|
||||
|
||||
# compare a string while respecting case mapping
|
||||
def equals(server, s1, s2):
|
||||
return lower(server, s1) == lower(server, s2)
|
||||
def equals(case_mapping: str, s1: str, s2: str) -> bool:
|
||||
return lower(case_mapping, s1) == lower(case_mapping, s2)
|
||||
|
||||
class IRCHostmask(object):
|
||||
def __init__(self, nickname, username, hostname, hostmask):
|
||||
def __init__(self, nickname: str, username: str, hostname: str,
|
||||
hostmask: str):
|
||||
self.nickname = nickname
|
||||
self.username = username
|
||||
self.hostname = hostname
|
||||
|
@ -42,24 +46,24 @@ class IRCHostmask(object):
|
|||
def __str__(self):
|
||||
return self.hostmask
|
||||
|
||||
def seperate_hostmask(hostmask):
|
||||
def seperate_hostmask(hostmask: str) -> IRCHostmask:
|
||||
hostmask = remove_colon(hostmask)
|
||||
nickname, _, username = hostmask.partition("!")
|
||||
username, _, hostname = username.partition("@")
|
||||
return IRCHostmask(nickname, username, hostname, hostmask)
|
||||
|
||||
|
||||
class IRCLine(object):
|
||||
def __init__(self, tags, prefix, command, args, arbitrary, last, server):
|
||||
def __init__(self, tags: dict, prefix: str, command: str,
|
||||
args: typing.List[str], arbitrary: typing.Optional[str],
|
||||
last: str):
|
||||
self.tags = tags
|
||||
self.prefix = prefix
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.arbitrary = arbitrary
|
||||
self.last = last
|
||||
self.server = server
|
||||
|
||||
def parse_line(server, line):
|
||||
def parse_line(line: str) -> IRCLine:
|
||||
tags = {}
|
||||
prefix = None
|
||||
command = None
|
||||
|
@ -81,7 +85,7 @@ def parse_line(server, line):
|
|||
args = line.split(" ")
|
||||
last = arbitrary or args[-1]
|
||||
|
||||
return IRCLine(tags, prefix, command, args, arbitrary, last, server)
|
||||
return IRCLine(tags, prefix, command, args, arbitrary, last)
|
||||
|
||||
COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3
|
||||
COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7
|
||||
|
@ -94,20 +98,20 @@ FONT_BOLD, FONT_ITALIC, FONT_UNDERLINE, FONT_INVERT = ("\x02", "\x1D",
|
|||
FONT_COLOR, FONT_RESET = "\x03", "\x0F"
|
||||
REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR)
|
||||
|
||||
def color(s, foreground, background=None):
|
||||
def color(s: str, foreground: str, background: str=None) -> str:
|
||||
foreground = str(foreground).zfill(2)
|
||||
if background:
|
||||
background = str(background).zfill(2)
|
||||
return "%s%s%s%s%s" % (FONT_COLOR, foreground,
|
||||
"" if not background else ",%s" % background, s, FONT_COLOR)
|
||||
|
||||
def bold(s):
|
||||
def bold(s: str) -> str:
|
||||
return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD)
|
||||
|
||||
def underline(s):
|
||||
def underline(s: str) -> str:
|
||||
return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE)
|
||||
|
||||
def strip_font(s):
|
||||
def strip_font(s: str) -> str:
|
||||
s = s.replace(FONT_BOLD, "")
|
||||
s = s.replace(FONT_ITALIC, "")
|
||||
s = REGEX_COLOR.sub("", s)
|
||||
|
|
57
src/utils/parse.py
Normal file
57
src/utils/parse.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import io, typing
|
||||
|
||||
COMMENT_TYPES = ["#", "//"]
|
||||
def hashflags(filename: str) -> typing.List[typing.Tuple[str, str]]:
|
||||
hashflags = {}
|
||||
with io.open(filename, mode="r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
line = line.strip("\n")
|
||||
found = False
|
||||
for comment_type in COMMENT_TYPES:
|
||||
if line.startswith(comment_type):
|
||||
line = line.replace(comment_type, "", 1).lstrip()
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
break
|
||||
elif line.startswith("--"):
|
||||
hashflag, sep, value = line[2:].partition(" ")
|
||||
hashflags[hashflag] = value if sep else None
|
||||
return list(hashflags.items())
|
||||
|
||||
class Docstring(object):
|
||||
def __init__(self, description: str, items: dict, var_items: dict):
|
||||
self.description = description
|
||||
self.items = items
|
||||
self.var_items = var_items
|
||||
|
||||
def docstring(s: str) -> Docstring:
|
||||
description = ""
|
||||
last_item = None
|
||||
items = {}
|
||||
var_items = {}
|
||||
if s:
|
||||
for line in s.split("\n"):
|
||||
line = line.strip()
|
||||
|
||||
if line:
|
||||
if line[0] == ":":
|
||||
key, _, value = line[1:].partition(": ")
|
||||
last_item = key
|
||||
|
||||
if key in var_items:
|
||||
var_items[key].append(value)
|
||||
elif key in items:
|
||||
var_items[key] = [items.pop(key), value]
|
||||
else:
|
||||
items[key] = value
|
||||
else:
|
||||
if last_item:
|
||||
items[last_item] += " %s" % line
|
||||
else:
|
||||
if description:
|
||||
description += " "
|
||||
description += line
|
||||
return Docstring(description, items, var_items)
|
||||
|
Loading…
Reference in a new issue