Add type/return hints throughout src/ and, in doing so, fix some cyclical

references.
This commit is contained in:
jesopo 2018-10-30 14:58:48 +00:00
parent 705daaf9bb
commit e07553c362
22 changed files with 605 additions and 516 deletions

View file

@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule):
if not event["user"].last_karma or (time.time()-event["user"
].last_karma) >= KARMA_DELAY_SECONDS:
target = match.group(1).strip()
if utils.irc.lower(event["server"], target
if utils.irc.lower(event["server"].case_mapping, target
) == event["user"].name:
if verbose:
self.events.on("send.stderr").call(

View file

@ -542,7 +542,8 @@ class Module(ModuleManager.BaseModule):
# we need a registered nickname for this channel
@utils.hook("raw.477", default_event=True)
def handle_477(self, event):
channel_name = utils.irc.lower(event["server"], event["args"][1])
channel_name = utils.irc.lower(event["server"].case_mapping,
event["args"][1])
if channel_name in event["server"]:
key = event["server"].attempted_join[channel_name]
self.timers.add("rejoin", 5, channel_name=channe_name, key=key,

View file

@ -11,12 +11,16 @@ REGEX_SED = re.compile("^s/")
"help": "Disable/Enable sed only looking at the messages sent by the user",
"validate": utils.bool_or_none})
class Module(ModuleManager.BaseModule):
def _closest_setting(self, event, setting, default):
return event["channel"].get_setting(setting,
event["server"].get_setting(setting, default))
@utils.hook("received.message.channel")
def channel_message(self, event):
sed_split = re.split(REGEX_SPLIT, event["message"], 3)
if event["message"].startswith("s/") and len(sed_split) > 2:
if event["action"] or not utils.get_closest_setting(
event, "sed", False):
if event["action"] or not self._closest_setting(event, "sed",
False):
return
regex_flags = 0
@ -48,9 +52,8 @@ class Module(ModuleManager.BaseModule):
return
replace = sed_split[2].replace("\\/", "/")
for_user = event["user"].nickname if utils.get_closest_setting(
event, "sed-sender-only", False
) else None
for_user = event["user"].nickname if self._closest_setting(event,
"sed-sender-only", False) else None
line = event["channel"].buffer.find(pattern, from_self=False,
for_user=for_user, not_pattern=REGEX_SED)
if line:

View file

@ -1,21 +1,21 @@
import time, uuid
import time, typing, uuid
class Cache(object):
def __init__(self):
self._items = {}
self._item_to_id = {}
def cache(self, item):
def cache(self, item: typing.Any) -> str:
return self._cache(item, None)
def temporary_cache(self, item, timeout):
def temporary_cache(self, item: typing.Any, timeout: float)-> str:
return self._cache(item, timeout)
def _cache(self, item, timeout):
def _cache(self, item: typing.Any, timeout: float) -> str:
id = str(uuid.uuid4())
self._items[id] = [item, time.monotonic()+timeout]
self._item_to_id[item] = id
return id
def next_expiration(self):
def next_expiration(self) -> float:
expirations = [self._items[id][1] for id in self._items]
expirations = list(filter(None, expirations))
if not expirations:
@ -35,17 +35,17 @@ class Cache(object):
del self._items[id]
del self._item_to_id[item]
def has_item(self, item):
def has_item(self, item: typing.Any) -> bool:
return item in self._item_to_id
def get(self, id):
def get(self, id: str) -> typing.Any:
item, expiration = self._items[id]
return item
def get_expiration(self, item):
def get_expiration(self, item: typing.Any) -> float:
id = self._item_to_id[item]
item, expiration = self._items[id]
return expiration
def until_expiration(self, item):
def until_expiration(self, item: typing.Any) -> float:
expiration = self.get_expiration(item)
return expiration-time.monotonic()

View file

@ -1,7 +1,7 @@
import configparser, os
import configparser, os, typing
class Config(object):
def __init__(self, location):
def __init__(self, location: str):
self.location = location
self._config = {}
self.load()
@ -13,10 +13,10 @@ class Config(object):
parser.read_string(config_file.read())
self._config = dict(parser["bot"].items())
def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
return self._config[key]
def get(self, key, default=None):
def get(self, key: str, default: typing.Any=None) -> typing.Any:
return self._config.get(key, default)
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self._config

View file

@ -1,12 +1,14 @@
import json, os, sqlite3, threading, time
import json, os, sqlite3, threading, time, typing
from src import Logging
class Table(object):
def __init__(self, database):
self.database = database
class Servers(Table):
def add(self, alias, hostname, port, password, ipv4, tls, bindhost,
nickname, username=None, realname=None):
def add(self, alias: str, hostname: str, port: int, password: str,
ipv4: bool, tls: bool, bindhost: str,
nickname: str, username: str=None, realname: str=None):
username = username or nickname
realname = realname or nickname
self.database.execute(
@ -18,7 +20,7 @@ class Servers(Table):
def get_all(self):
return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers")
def get(self, id):
def get(self, id: int):
return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname FROM servers WHERE
@ -26,46 +28,46 @@ class Servers(Table):
[id])
class Channels(Table):
def add(self, server_id, name):
def add(self, server_id: int, name: str):
self.database.execute("""INSERT OR IGNORE INTO channels
(server_id, name) VALUES (?, ?)""",
[server_id, name.lower()])
def delete(self, channel_id):
def delete(self, channel_id: int):
self.database.execute("DELETE FROM channels WHERE channel_id=?",
[channel_id])
def get_id(self, server_id, name):
def get_id(self, server_id: int, name: str):
value = self.database.execute_fetchone("""SELECT channel_id FROM
channels WHERE server_id=? AND name=?""",
[server_id, name.lower()])
return value if value == None else value[0]
class Users(Table):
def add(self, server_id, nickname):
def add(self, server_id: int, nickname: str):
self.database.execute("""INSERT OR IGNORE INTO users
(server_id, nickname) VALUES (?, ?)""",
[server_id, nickname.lower()])
def delete(self, user_id):
def delete(self, user_id: int):
self.database.execute("DELETE FROM users WHERE user_id=?",
[user_id])
def get_id(self, server_id, nickname):
def get_id(self, server_id: int, nickname: str):
value = self.database.execute_fetchone("""SELECT user_id FROM
users WHERE server_id=? and nickname=?""",
[server_id, nickname.lower()])
return value if value == None else value[0]
class BotSettings(Table):
def set(self, setting, value):
def set(self, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO bot_settings VALUES (?, ?)",
[setting.lower(), json.dumps(value)])
def get(self, setting, default=None):
def get(self, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"SELECT value FROM bot_settings WHERE setting=?",
[setting.lower()])
if value:
return json.loads(value[0])
return default
def find(self, pattern, default=[]):
def find(self, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"SELECT setting, value FROM bot_settings WHERE setting LIKE ?",
[pattern.lower()])
@ -74,19 +76,19 @@ class BotSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_prefix(self, prefix, default=[]):
def find_prefix(self, prefix: str, default: typing.Any=[]):
return self.find("%s%%" % prefix, default)
def delete(self, setting):
def delete(self, setting: str):
self.database.execute(
"DELETE FROM bot_settings WHERE setting=?",
[setting.lower()])
class ServerSettings(Table):
def set(self, server_id, setting, value):
def set(self, server_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)",
[server_id, setting.lower(), json.dumps(value)])
def get(self, server_id, setting, default=None):
def get(self, server_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM server_settings WHERE
server_id=? AND setting=?""",
@ -94,7 +96,7 @@ class ServerSettings(Table):
if value:
return json.loads(value[0])
return default
def find(self, server_id, pattern, default=[]):
def find(self, server_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM server_settings WHERE
server_id=? AND setting LIKE ?""",
@ -104,26 +106,26 @@ class ServerSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_prefix(self, server_id, prefix, default=[]):
def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]):
return self.find_server_settings(server_id, "%s%%" % prefix, default)
def delete(self, server_id, setting):
def delete(self, server_id: int, setting: str):
self.database.execute(
"DELETE FROM server_settings WHERE server_id=? AND setting=?",
[server_id, setting.lower()])
class ChannelSettings(Table):
def set(self, channel_id, setting, value):
def set(self, channel_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)",
[channel_id, setting.lower(), json.dumps(value)])
def get(self, channel_id, setting, default=None):
def get(self, channel_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM channel_settings WHERE
channel_id=? AND setting=?""", [channel_id, setting.lower()])
if value:
return json.loads(value[0])
return default
def find(self, channel_id, pattern, default=[]):
def find(self, channel_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM channel_settings WHERE
channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()])
@ -132,15 +134,15 @@ class ChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_prefix(self, channel_id, prefix, default=[]):
def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]):
return self.find_channel_settings(channel_id, "%s%%" % prefix,
default)
def delete(self, channel_id, setting):
def delete(self, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM channel_settings WHERE channel_id=?
AND setting=?""", [channel_id, setting.lower()])
def find_by_setting(self, setting, default=[]):
def find_by_setting(self, setting: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.server_id, channels.name,
channel_settings.value FROM channel_settings
@ -154,18 +156,19 @@ class ChannelSettings(Table):
return default
class UserSettings(Table):
def set(self, user_id, setting, value):
def set(self, user_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)",
[user_id, setting.lower(), json.dumps(value)])
def get(self, user_id, setting, default=None):
def get(self, user_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_settings WHERE
user_id=? and setting=?""", [user_id, setting.lower()])
if value:
return json.loads(value[0])
return default
def find_all_by_setting(self, server_id, setting, default=[]):
def find_all_by_setting(self, server_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT users.nickname, user_settings.value FROM
user_settings INNER JOIN users ON
@ -177,7 +180,7 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find(self, user_id, pattern, default=[]):
def find(self, user_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute(
"""SELECT setting, value FROM user_settings WHERE
user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()])
@ -186,20 +189,22 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_prefix(self, user_id, prefix, default=[]):
def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]):
return self.find_user_settings(user_id, "%s%%" % prefix, default)
def delete(self, user_id, setting):
def delete(self, user_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_settings WHERE
user_id=? AND setting=?""", [user_id, setting.lower()])
class UserChannelSettings(Table):
def set(self, user_id, channel_id, setting, value):
def set(self, user_id: int, channel_id: int, setting: str,
value: typing.Any):
self.database.execute(
"""INSERT OR REPLACE INTO user_channel_settings VALUES
(?, ?, ?, ?)""",
[user_id, channel_id, setting.lower(), json.dumps(value)])
def get(self, user_id, channel_id, setting, default=None):
def get(self, user_id: int, channel_id: int, setting: str,
default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
@ -207,7 +212,8 @@ class UserChannelSettings(Table):
if value:
return json.loads(value[0])
return default
def find(self, user_id, channel_id, pattern, default=[]):
def find(self, user_id: int, channel_id: int, pattern: str,
default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting LIKE '?'""",
@ -217,10 +223,12 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_prefix(self, user_id, channel_id, prefix, default=[]):
def find_prefix(self, user_id: int, channel_id: int, prefix: str,
default: typing.Any=[]):
return self.find_user_settings(user_id, channel_id, "%s%%" % prefix,
default)
def find_by_setting(self, user_id, setting, default=[]):
def find_by_setting(self, user_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, user_channel_settings.value FROM
user_channel_settings INNER JOIN channels ON
@ -232,7 +240,8 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
def find_all_by_setting(self, server_id, setting, default=[]):
def find_all_by_setting(self, server_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, users.nickname,
user_channel_settings.value FROM
@ -246,14 +255,14 @@ class UserChannelSettings(Table):
values[i] = value[0], value[1], json.loads(value[2])
return values
return default
def delete(self, user_id, channel_id, setting):
def delete(self, user_id: int, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
[user_id, channel_id, setting.lower()])
class Database(object):
def __init__(self, log, location):
def __init__(self, log: "Logging.Log", location: str):
self.log = log
self.location = location
self.database = sqlite3.connect(self.location,
@ -284,7 +293,9 @@ class Database(object):
self._cursor = self.database.cursor()
return self._cursor
def _execute_fetch(self, query, fetch_func, params=[]):
def _execute_fetch(self, query: str,
fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
params: typing.List=[]):
printable_query = " ".join(query.split())
self.log.trace("executing query: \"%s\" (params: %s)",
[printable_query, params])
@ -299,16 +310,16 @@ class Database(object):
self.log.trace("executed in %fms", [total_milliseconds])
return value
def execute_fetchall(self, query, params=[]):
def execute_fetchall(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchall(), params)
def execute_fetchone(self, query, params=[]):
def execute_fetchone(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchone(), params)
def execute(self, query, params=[]):
def execute(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, lambda cursor: None, params)
def has_table(self, table_name):
def has_table(self, table_name: str):
result = self.execute_fetchone("""SELECT COUNT(*) FROM
sqlite_master WHERE type='table' AND name=?""",
[table_name])

View file

@ -1,5 +1,5 @@
import itertools, time, traceback
from src import utils
import itertools, time, traceback, typing
from src import Logging, utils
PRIORITY_URGENT = 0
PRIORITY_HIGH = 1
@ -11,94 +11,39 @@ DEFAULT_PRIORITY = PRIORITY_MEDIUM
DEFAULT_EVENT_DELIMITER = "."
DEFAULT_MULTI_DELIMITER = "|"
CALLBACK_TYPE = typing.Callable[["Event"], typing.Any]
class Event(object):
def __init__(self, name, **kwargs):
def __init__(self, name: str, **kwargs):
self.name = name
self.kwargs = kwargs
self.eaten = False
def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
return self.kwargs[key]
def get(self, key, default=None):
def get(self, key: str, default=None) -> typing.Any:
return self.kwargs.get(key, default)
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self.kwargs
def eat(self):
self.eaten = True
class EventCallback(object):
def __init__(self, function, priority, kwargs):
def __init__(self, function: CALLBACK_TYPE, priority: int, kwargs: dict):
self.function = function
self.priority = priority
self.kwargs = kwargs
self.docstring = utils.parse_docstring(function.__doc__)
self.docstring = utils.parse.docstring(function.__doc__)
def call(self, event):
def call(self, event: Event) -> typing.Any:
return self.function(event)
def get_kwarg(self, name, default=None):
def get_kwarg(self, name: str, default=None) -> typing.Any:
item = self.kwargs.get(name, default)
return item or self.docstring.items.get(name, default)
class MultipleEventHook(object):
def __init__(self):
self._event_hooks = set([])
def _add(self, event_hook):
self._event_hooks.add(event_hook)
def hook(self, function, **kwargs):
for event_hook in self._event_hooks:
event_hook.hook(function, **kwargs)
def call_limited(self, maximum, **kwargs):
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call_limited(maximum, **kwargs))
return returns
def call(self, **kwargs):
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call(**kwargs))
return returns
class EventHookContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
**kwargs):
return self._parent._context_hook(self.context, function, priority,
replay, kwargs)
def unhook(self, callback):
self._parent.unhook(callback)
def on(self, subevent, *extra_subevents,
delimiter=DEFAULT_EVENT_DELIMITER):
return self._parent._context_on(self.context, subevent,
extra_subevents, delimiter)
def call_for_result(self, default=None, **kwargs):
return self._parent.call_for_result(default, **kwargs)
def assure_call(self, **kwargs):
self._parent.assure_call(**kwargs)
def call(self, **kwargs):
return self._parent.call(**kwargs)
def call_limited(self, maximum, **kwargs):
return self._parent.call_limited(maximum, **kwargs)
def call_unsafe_for_result(self, default=None, **kwargs):
return self._parent.call_unsafe_for_result(default, **kwargs)
def call_unsafe(self, **kwargs):
return self._parent.call_unsafe(**kwargs)
def call_unsafe_limited(self, maximum, **kwargs):
return self._parent.call_unsafe_limited(maximum, **kwargs)
def get_hooks(self):
return self._parent.get_hooks()
def get_children(self):
return self._parent.get_children()
class EventHook(object):
def __init__(self, log, name=None, parent=None):
def __init__(self, log: Logging.Log, name: str = None,
parent: "EventHook" = None):
self.log = log
self.name = name
self.parent = parent
@ -107,10 +52,10 @@ class EventHook(object):
self._stored_events = []
self._context_hooks = {}
def _make_event(self, kwargs):
def _make_event(self, kwargs: dict) -> Event:
return Event(self._get_path(), **kwargs)
def _get_path(self):
def _get_path(self) -> str:
path = []
parent = self
while not parent == None and not parent.name == None:
@ -118,15 +63,17 @@ class EventHook(object):
parent = parent.parent
return DEFAULT_EVENT_DELIMITER.join(path[::-1])
def new_context(self, context):
def new_context(self, context: str) -> "EventHookContext":
return EventHookContext(self, context)
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
**kwargs):
def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
replay: bool = False, **kwargs) -> EventCallback:
return self._hook(function, None, priority, replay, kwargs)
def _context_hook(self, context, function, priority, replay, kwargs):
def _context_hook(self, context: str, function: CALLBACK_TYPE,
priority: int, replay: bool, kwargs: dict) -> EventCallback:
return self._hook(function, context, priority, replay, kwargs)
def _hook(self, function, context, priority, replay, kwargs):
def _hook(self, function: CALLBACK_TYPE, context: str, priority: int,
replay: bool, kwargs: dict) -> EventCallback:
callback = EventCallback(function, priority, kwargs)
if context == None:
@ -142,7 +89,7 @@ class EventHook(object):
self._stored_events = None
return callback
def unhook(self, callback):
def unhook(self, callback: "EventHook"):
if callback in self._hooks:
self._hooks.remove(callback)
@ -155,7 +102,8 @@ class EventHook(object):
for context in empty:
del self._context_hooks[context]
def _make_multiple_hook(self, source, context, events):
def _make_multiple_hook(self, source: "EventHook", context: str,
events: typing.List[str]) -> "MultipleEventHook":
multiple_event_hook = MultipleEventHook()
for event in events:
event_hook = source.get_child(event)
@ -164,13 +112,15 @@ class EventHook(object):
multiple_event_hook._add(event_hook)
return multiple_event_hook
def on(self, subevent, *extra_subevents,
delimiter=DEFAULT_EVENT_DELIMITER):
def on(self, subevent: str, *extra_subevents,
delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, None, delimiter)
def _context_on(self, context, subevent, extra_subevents,
delimiter=DEFAULT_EVENT_DELIMITER):
def _context_on(self, context: str, subevent: str,
extra_subevents: typing.List[str],
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, context, delimiter)
def _on(self, subevent, extra_subevents, context, delimiter):
def _on(self, subevent: str, extra_subevents: typing.List[str],
context: str, delimiter: str) -> "EventHook":
if delimiter in subevent:
event_chain = subevent.split(delimiter)
event_obj = self
@ -193,26 +143,28 @@ class EventHook(object):
child = child.new_context(context)
return child
def call_for_result(self, default=None, **kwargs):
def call_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_limited(1, **kwargs) or [default])[0]
def assure_call(self, **kwargs):
if not self._stored_events == None:
self._stored_events.append(kwargs)
else:
self._call(kwargs, True, None)
def call(self, **kwargs):
def call(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
def call_limited(self, maximum, **kwargs):
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
def call_unsafe_for_result(self, default=None, **kwargs):
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
def call_unsafe(self, **kwargs):
def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, False, None)
def call_unsafe_limited(self, maximum, **kwargs):
def call_unsafe_limited(self, maximum: int, **kwargs
) -> typing.List[typing.Any]:
return self._call(kwargs, False, maximum)
def _call(self, kwargs, safe, maximum):
def _call(self, kwargs: dict, safe: bool, maximum: int
) -> typing.List[typing.Any]:
event_path = self._get_path()
self.log.trace("calling event: \"%s\" (params: %s)",
[event_path, kwargs])
@ -240,13 +192,13 @@ class EventHook(object):
return returns
def get_child(self, child_name):
def get_child(self, child_name: str) -> "EventHook":
child_name_lower = child_name.lower()
if not child_name_lower in self._children:
self._children[child_name_lower] = EventHook(self.log,
child_name_lower, self)
return self._children[child_name_lower]
def remove_child(self, child_name):
def remove_child(self, child_name: str):
child_name_lower = child_name.lower()
if child_name_lower in self._children:
del self._children[child_name_lower]
@ -256,11 +208,11 @@ class EventHook(object):
self.parent.remove_child(self.name)
self.parent.check_purge()
def remove_context(self, context):
def remove_context(self, context: str):
del self._context_hooks[context]
def has_context(self, context):
def has_context(self, context: str) -> bool:
return context in self._context_hooks
def purge_context(self, context):
def purge_context(self, context: str):
if self.has_context(context):
self.remove_context(context)
@ -268,10 +220,69 @@ class EventHook(object):
child = self.get_child(child_name)
child.purge_context(context)
def get_hooks(self):
def get_hooks(self) -> typing.List[EventCallback]:
return sorted(self._hooks + sum(self._context_hooks.values(), []),
key=lambda e: e.priority)
def get_children(self):
def get_children(self) -> typing.List["EventHook"]:
return list(self._children.keys())
def is_empty(self):
def is_empty(self) -> bool:
return len(self.get_hooks() + self.get_children()) == 0
class MultipleEventHook(object):
def __init__(self):
self._event_hooks = set([])
def _add(self, event_hook: EventHook):
self._event_hooks.add(event_hook)
def hook(self, function: CALLBACK_TYPE, **kwargs):
for event_hook in self._event_hooks:
event_hook.hook(function, **kwargs)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call_limited(maximum, **kwargs))
return returns
def call(self, **kwargs) -> typing.List[typing.Any]:
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call(**kwargs))
return returns
class EventHookContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
replay: bool = False, **kwargs) -> EventCallback:
return self._parent._context_hook(self.context, function, priority,
replay, kwargs)
def unhook(self, callback: EventCallback):
self._parent.unhook(callback)
def on(self, subevent: str, *extra_subevents,
delimiter: str = DEFAULT_EVENT_DELIMITER) -> EventHook:
return self._parent._context_on(self.context, subevent,
extra_subevents, delimiter)
def call_for_result(self, default=None, **kwargs) -> typing.Any:
return self._parent.call_for_result(default, **kwargs)
def assure_call(self, **kwargs):
self._parent.assure_call(**kwargs)
def call(self, **kwargs) -> typing.List[typing.Any]:
return self._parent.call(**kwargs)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._parent.call_limited(maximum, **kwargs)
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return self._parent.call_unsafe_for_result(default, **kwargs)
def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
return self._parent.call_unsafe(**kwargs)
def call_unsafe_limited(self, maximum: int, **kwargs
) -> typing.List[typing.Any]:
return self._parent.call_unsafe_limited(maximum, **kwargs)
def get_hooks(self) -> typing.List[EventCallback]:
return self._parent.get_hooks()
def get_children(self) -> typing.List[EventHook]:
return self._parent.get_children()

View file

@ -1,28 +1,18 @@
class ExportsContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def add(self, setting, value):
self._parent._context_add(self.context, setting, value)
def get_all(self, setting):
return self._parent.get_all(setting)
import typing
class Exports(object):
def __init__(self):
self._exports = {}
self._context_exports = {}
def new_context(self, context):
def new_context(self, context: str) -> "ExportsContext":
return ExportsContext(self, context)
def add(self, setting, value):
def add(self, setting: str, value: typing.Any):
self._add(None, setting, value)
def _context_add(self, context, setting, value):
def _context_add(self, context: str, setting: str, value: typing.Any):
self._add(context, setting, value)
def _add(self, context, setting, value):
def _add(self, context: str, setting: str, value: typing.Any):
if context == None:
if not setting in self_exports:
self._exports[setting] = []
@ -34,11 +24,21 @@ class Exports(object):
self._context_exports[context][setting] = []
self._context_exports[context][setting].append(value)
def get_all(self, setting):
def get_all(self, setting: str) -> typing.List[typing.Any]:
return self._exports.get(setting, []) + sum([
exports.get(setting, []) for exports in
self._context_exports.values()], [])
def purge_context(self, context):
def purge_context(self, context: str):
if context in self._context_exports:
del self._context_exports[context]
class ExportsContext(object):
def __init__(self, parent: Exports, context: str):
self._parent = parent
self.context = context
def add(self, setting: str, value: typing.Any):
self._parent._context_add(self.context, setting, value)
def get_all(self, setting: str) -> typing.List[typing.Any]:
return self._parent.get_all(setting)

View file

@ -1,4 +1,4 @@
import os, select, socket, sys, threading, time, traceback, uuid
import os, select, socket, sys, threading, time, traceback, typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import Socket, utils
@ -28,14 +28,15 @@ class Bot(object):
self._trigger_functions = []
def trigger(self, func=None):
def trigger(self, func: typing.Callable[[], typing.Any]=None):
self.lock.acquire()
if func:
self._trigger_functions.append(func)
self._trigger_client.send(b"TRIGGER")
self.lock.release()
def add_server(self, server_id, connect=True):
def add_server(self, server_id: int, connect: bool = True
) -> typing.Optional[IRCServer.Server]:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id)
@ -49,20 +50,20 @@ class Bot(object):
self.connect(new_server)
return new_server
def add_socket(self, sock):
def add_socket(self, sock: socket.socket):
self.other_sockets[sock.fileno()] = sock
self.poll.register(sock.fileno(), select.EPOLLIN)
def remove_socket(self, sock):
def remove_socket(self, sock: socket.socket):
del self.other_sockets[sock.fileno()]
self.poll.unregister(sock.fileno())
def get_server(self, id):
def get_server(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values():
if server.id == id:
return server
def connect(self, server):
def connect(self, server: IRCServer.Server) -> bool:
try:
server.connect()
except:
@ -73,7 +74,7 @@ class Bot(object):
self.poll.register(server.fileno(), select.EPOLLOUT)
return True
def next_send(self):
def next_send(self) -> typing.Optional[float]:
next = None
for server in self.servers.values():
timeout = server.send_throttle_timeout()
@ -81,7 +82,7 @@ class Bot(object):
next = timeout
return next
def next_ping(self):
def next_ping(self) -> typing.Optional[float]:
timeouts = []
for server in self.servers.values():
timeout = server.until_next_ping()
@ -90,7 +91,8 @@ class Bot(object):
if not timeouts:
return None
return min(timeouts)
def next_read_timeout(self):
def next_read_timeout(self) -> typing.Optional[float]:
timeouts = []
for server in self.servers.values():
timeouts.append(server.until_read_timeout())
@ -98,7 +100,7 @@ class Bot(object):
return None
return min(timeouts)
def get_poll_timeout(self):
def get_poll_timeout(self) -> float:
timeouts = []
timeouts.append(self._timers.next())
timeouts.append(self.next_send())
@ -107,15 +109,15 @@ class Bot(object):
timeouts.append(self.cache.next_expiration())
return min([timeout for timeout in timeouts if not timeout == None])
def register_read(self, server):
def register_read(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLIN)
def register_write(self, server):
def register_write(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLOUT)
def register_both(self, server):
def register_both(self, server: IRCServer.Server):
self.poll.modify(server.fileno(),
select.EPOLLIN|select.EPOLLOUT)
def disconnect(self, server):
def disconnect(self, server: IRCServer.Server):
try:
self.poll.unregister(server.fileno())
except FileNotFoundError:
@ -123,23 +125,25 @@ class Bot(object):
del self.servers[server.fileno()]
@utils.hook("timer.reconnect")
def reconnect(self, event):
def reconnect(self, event: EventManager.Event):
server = self.add_server(event["server_id"], False)
if self.connect(server):
self.servers[server.fileno()] = server
else:
event["timer"].redo()
def set_setting(self, setting, value):
def set_setting(self, setting: str, value: typing.Any):
self.database.bot_settings.set(setting, value)
def get_setting(self, setting, default=None):
def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.database.bot_settings.get(setting, default)
def find_settings(self, pattern, default=[]):
def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.database.bot_settings.find(pattern, default)
def find_settings_prefix(self, prefix, default=[]):
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.database.bot_settings.find_prefix(
prefix, default)
def del_setting(self, setting):
def del_setting(self, setting: str):
self.database.bot_settings.delete(setting)
def run(self):

View file

@ -1,8 +1,9 @@
import re
from src import utils
import re, typing
from src import IRCBot, utils
class BufferLine(object):
def __init__(self, sender, message, action, tags, from_self, method):
def __init__(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool, method: str):
self.sender = sender
self.message = message
self.action = action
@ -11,35 +12,39 @@ class BufferLine(object):
self.method = method
class Buffer(object):
def __init__(self, bot, server):
def __init__(self, bot: "IRCBot.Bot", server: "IRCServer.Server"):
self.bot = bot
self.server = server
self.lines = []
self.max_lines = 64
self._skip_next = False
def _add_message(self, sender, message, action, tags, from_self, method):
def _add_message(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool, method: str):
if not self._skip_next:
line = BufferLine(sender, message, action, tags, from_self, method)
self.lines.insert(0, line)
if len(self.lines) > self.max_lines:
self.lines.pop()
self._skip_next = False
def add_message(self, sender, message, action, tags, from_self=False):
def add_message(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool=False):
self._add_message(sender, message, action, tags, from_self, "PRIVMSG")
def add_notice(self, sender, message, tags, from_self=False):
def add_notice(self, sender: str, message: str, tags: dict,
from_self: bool=False):
self._add_message(sender, message, False, tags, from_self, "NOTICE")
def get(self, index=0, **kwargs):
def get(self, index: int=0, **kwargs) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True)
for line in self.lines:
if line.from_self and not from_self:
continue
return line
def find(self, pattern, **kwargs):
def find(self, pattern: typing.Union[str, typing.Pattern[str]], **kwargs
) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True)
for_user = kwargs.get("for_user", "")
for_user = utils.irc.lower(self.server, for_user
for_user = utils.irc.lower(self.server.case_mapping, for_user
) if for_user else None
not_pattern = kwargs.get("not_pattern", None)
for line in self.lines:
@ -48,8 +53,8 @@ class Buffer(object):
elif re.search(pattern, line.message):
if not_pattern and re.search(not_pattern, line.message):
continue
if for_user and not utils.irc.lower(self.server, line.sender
) == for_user:
if for_user and not utils.irc.lower(self.server.case_mapping,
line.sender) == for_user:
continue
return line
def skip_next(self):

View file

@ -1,9 +1,10 @@
import uuid
from src import IRCBuffer, IRCObject, utils
import typing, uuid
from src import IRCBot, IRCBuffer, IRCObject, IRCServer, IRCUser, utils
class Channel(IRCObject.Object):
def __init__(self, name, id, server, bot):
self.name = utils.irc.lower(server, name)
def __init__(self, name: str, id, server: "IRCServer.Server",
bot: "IRCBot.Bot"):
self.name = utils.irc.lower(server.case_mapping, name)
self.id = id
self.server = server
self.bot = bot
@ -18,23 +19,24 @@ class Channel(IRCObject.Object):
self.created_timestamp = None
self.buffer = IRCBuffer.Buffer(bot, server)
def __repr__(self):
def __repr__(self) -> str:
return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name)
def __str__(self):
def __str__(self) -> str:
return self.name
def set_topic(self, topic):
def set_topic(self, topic: str):
self.topic = topic
def set_topic_setter(self, nickname, username=None, hostname=None):
def set_topic_setter(self, nickname: str, username: str=None,
hostname: str=None):
self.topic_setter_nickname = nickname
self.topic_setter_username = username
self.topic_setter_hostname = hostname
def set_topic_time(self, unix_timestamp):
def set_topic_time(self, unix_timestamp: int):
self.topic_time = unix_timestamp
def add_user(self, user):
def add_user(self, user: IRCUser.User):
self.users.add(user)
def remove_user(self, user):
def remove_user(self, user: IRCUser.User):
self.users.remove(user)
for mode in list(self.modes.keys()):
if mode in self.server.prefix_modes and user in self.modes[mode]:
@ -43,10 +45,10 @@ class Channel(IRCObject.Object):
del self.modes[mode]
if user in self.user_modes:
del self.user_modes[user]
def has_user(self, user):
def has_user(self, user: IRCUser.User) -> bool:
return user in self.users
def add_mode(self, mode, arg=None):
def add_mode(self, mode: str, arg: str=None):
if not mode in self.modes:
self.modes[mode] = set([])
if arg:
@ -59,7 +61,7 @@ class Channel(IRCObject.Object):
self.user_modes[user].add(mode)
else:
self.modes[mode].add(arg.lower())
def remove_mode(self, mode, arg=None):
def remove_mode(self, mode: str, arg: str=None):
if not arg:
del self.modes[mode]
else:
@ -76,63 +78,70 @@ class Channel(IRCObject.Object):
self.modes[mode].discard(arg.lower())
if not len(self.modes[mode]):
del self.modes[mode]
def change_mode(self, remove, mode, arg=None):
def change_mode(self, remove: bool, mode: str, arg: str=None):
if remove:
self.remove_mode(mode, arg)
else:
self.add_mode(mode, arg)
def set_setting(self, setting, value):
def set_setting(self, setting: str, value: typing.Any):
self.bot.database.channel_settings.set(self.id, setting, value)
def get_setting(self, setting, default=None):
def get_setting(self, setting: str, default: typing.Any=None
) -> typing.Any:
return self.bot.database.channel_settings.get(self.id, setting,
default)
def find_settings(self, pattern, default=[]):
def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find(self.id, pattern,
default)
def find_settings_prefix(self, prefix, default=[]):
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find_prefix(self.id,
prefix, default)
def del_setting(self, setting):
def del_setting(self, setting: str):
self.bot.database.channel_settings.delete(self.id, setting)
def set_user_setting(self, user_id, setting, value):
def set_user_setting(self, user_id: int, setting: str, value: typing.Any):
self.bot.database.user_channel_settings.set(user_id, self.id,
setting, value)
def get_user_setting(self, user_id, setting, default=None):
def get_user_setting(self, user_id: int, setting: str,
default: typing.Any=None) -> typing.Any:
return self.bot.database.user_channel_settings.get(user_id,
self.id, setting, default)
def find_user_settings(self, user_i, pattern, default=[]):
def find_user_settings(self, user_id: int, pattern: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find(user_id,
self.id, pattern, default)
def find_user_settings_prefix(self, user_id, prefix, default=[]):
def find_user_settings_prefix(self, user_id: int, prefix: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_prefix(
user_id, self.id, prefix, default)
def del_user_setting(self, user_id, setting):
def del_user_setting(self, user_id: int, setting: str):
self.bot.database.user_channel_settings.delete(user_id, self.id,
setting)
def find_all_by_setting(self, setting, default=[]):
def find_all_by_setting(self, setting: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default)
def send_message(self, text, prefix=None, tags={}):
def send_message(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.name, text, prefix=prefix, tags=tags)
def send_notice(self, text, prefix=None, tags={}):
def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.name, text, prefix=prefix, tags=tags)
def send_mode(self, mode=None, target=None):
def send_mode(self, mode: str=None, target: str=None):
self.server.send_mode(self.name, mode, target)
def send_kick(self, target, reason=None):
def send_kick(self, target: str, reason: str=None):
self.server.send_kick(self.name, target, reason)
def send_ban(self, hostmask):
def send_ban(self, hostmask: str):
self.server.send_mode(self.name, "+b", hostmask)
def send_unban(self, hostmask):
def send_unban(self, hostmask: str):
self.server.send_mode(self.name, "-b", hostmask)
def send_topic(self, topic):
def send_topic(self, topic: str):
self.server.send_topic(self.name, topic)
def send_part(self, reason=None):
def send_part(self, reason: str=None):
self.server.send_part(self.name, reason)
def mode_or_above(self, user, mode):
def mode_or_above(self, user: IRCUser.User, mode: str) -> bool:
mode_orders = list(self.server.prefix_modes)
mode_index = mode_orders.index(mode)
for mode in mode_orders[:mode_index+1]:
@ -140,8 +149,8 @@ class Channel(IRCObject.Object):
return True
return False
def has_mode(self, user, mode):
def has_mode(self, user: IRCUser.User, mode: str) -> bool:
return user in self.modes.get(mode, [])
def get_user_status(self, user):
def get_user_status(self, user: IRCUser.User) -> typing.Set:
return self.user_modes.get(user, [])

View file

@ -1,5 +1,5 @@
import collections, socket, ssl, sys, time
from src import IRCChannel, IRCObject, IRCUser, utils
import collections, socket, ssl, sys, time, typing
from src import EventManager, IRCBot, IRCChannel, IRCObject, IRCUser, utils
THROTTLE_LINES = 4
THROTTLE_SECONDS = 1
@ -7,8 +7,12 @@ READ_TIMEOUT_SECONDS = 120
PING_INTERVAL_SECONDS = 30
class Server(IRCObject.Object):
def __init__(self, bot, events, id, alias, hostname, port, password,
ipv4, tls, bindhost, nickname, username, realname):
def __init__(self,
bot: "IRCBot.Bot",
events: EventManager.EventHook,
id: int, alias: str, hostname: str, port: int, password: str,
ipv4: bool, tls: bool, bindhost: str,
nickname: str, username: str, realname: str):
self.connected = False
self.bot = bot
self.events = events
@ -121,77 +125,80 @@ class Server(IRCObject.Object):
except:
pass
def set_setting(self, setting, value):
def set_setting(self, setting: str, value: typing.Any):
self.bot.database.server_settings.set(self.id, setting,
value)
def get_setting(self, setting, default=None):
def get_setting(self, setting: str, default: typing.Any=None):
return self.bot.database.server_settings.get(self.id,
setting, default)
def find_settings(self, pattern, default=[]):
def find_settings(self, pattern: str, default: typing.Any=[]):
return self.bot.database.server_settings.find(self.id,
pattern, default)
def find_settings_prefix(self, prefix, default=[]):
def find_settings_prefix(self, prefix: str, default: typing.Any=[]):
return self.bot.database.server_settings.find_prefix(
self.id, prefix, default)
def del_setting(self, setting):
def del_setting(self, setting: str):
self.bot.database.server_settings.delete(self.id, setting)
def get_user_setting(self, nickname, setting, default=None):
def get_user_setting(self, nickname: str, setting: str,
default: typing.Any=None):
user_id = self.get_user_id(nickname)
return self.bot.database.user_settings.get(user_id, setting, default)
def set_user_setting(self, nickname, setting, value):
def set_user_setting(self, nickname: str, setting: str, value: typing.Any):
user_id = self.get_user_id(nickname)
self.bot.database.user_settings.set(user_id, setting, value)
def get_all_user_settings(self, setting, default=[]):
def get_all_user_settings(self, setting: str, default: typing.Any=[]):
return self.bot.database.user_settings.find_all_by_setting(
self.id, setting, default)
def find_all_user_channel_settings(self, setting, default=[]):
def find_all_user_channel_settings(self, setting: str,
default: typing.Any=[]):
return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default)
def set_own_nickname(self, nickname):
def set_own_nickname(self, nickname: str):
self.nickname = nickname
self.nickname_lower = utils.irc.lower(self, nickname)
def is_own_nickname(self, nickname):
self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
def is_own_nickname(self, nickname: str):
return utils.irc.equals(self, nickname, self.nickname)
def add_own_mode(self, mode, arg=None):
def add_own_mode(self, mode: str, arg: str=None):
self.own_modes[mode] = arg
def remove_own_mode(self, mode):
def remove_own_mode(self, mode: str):
del self.own_modes[mode]
def change_own_mode(self, remove, mode, arg=None):
def change_own_mode(self, remove: bool, mode: str, arg: str=None):
if remove:
self.remove_own_mode(mode)
else:
self.add_own_mode(mode, arg)
def has_user(self, nickname):
return utils.irc.lower(self, nickname) in self.users
def get_user(self, nickname, create=True):
def has_user(self, nickname: str):
return utils.irc.lower(self.case_mapping, nickname) in self.users
def get_user(self, nickname: str, create: bool=True):
if not self.has_user(nickname) and create:
user_id = self.get_user_id(nickname)
new_user = IRCUser.User(nickname, user_id, self, self.bot)
self.events.on("new.user").call(user=new_user, server=self)
self.users[new_user.nickname_lower] = new_user
self.new_users.add(new_user)
return self.users.get(utils.irc.lower(self, nickname), None)
def get_user_id(self, nickname):
return self.users.get(utils.irc.lower(self.case_mapping, nickname),
None)
def get_user_id(self, nickname: str):
self.bot.database.users.add(self.id, nickname)
return self.bot.database.users.get_id(self.id, nickname)
def remove_user(self, user):
def remove_user(self, user: IRCUser.User):
del self.users[user.nickname_lower]
for channel in user.channels:
channel.remove_user(user)
def change_user_nickname(self, old_nickname, new_nickname):
user = self.users.pop(utils.irc.lower(self, old_nickname))
def change_user_nickname(self, old_nickname: str, new_nickname: str):
user = self.users.pop(utils.irc.lower(self.case_mapping, old_nickname))
user._id = self.get_user_id(new_nickname)
self.users[utils.irc.lower(self, new_nickname)] = user
def has_channel(self, channel_name):
self.users[utils.irc.lower(self.case_mapping, new_nickname)] = user
def has_channel(self, channel_name: str):
return channel_name[0] in self.channel_types and utils.irc.lower(
self, channel_name) in self.channels
def get_channel(self, channel_name):
self.case_mapping, channel_name) in self.channels
def get_channel(self, channel_name: str):
if not self.has_channel(channel_name):
channel_id = self.get_channel_id(channel_name)
new_channel = IRCChannel.Channel(channel_name, channel_id,
@ -199,15 +206,15 @@ class Server(IRCObject.Object):
self.events.on("new.channel").call(channel=new_channel,
server=self)
self.channels[new_channel.name] = new_channel
return self.channels[utils.irc.lower(self, channel_name)]
def get_channel_id(self, channel_name):
return self.channels[utils.irc.lower(self.case_mapping, channel_name)]
def get_channel_id(self, channel_name: str):
self.bot.database.channels.add(self.id, channel_name)
return self.bot.database.channels.get_id(self.id, channel_name)
def remove_channel(self, channel):
def remove_channel(self, channel: IRCChannel.Channel):
for user in channel.users:
user.part_channel(channel)
del self.channels[channel.name]
def parse_data(self, line):
def parse_data(self, line: str):
if not line:
return
self.events.on("raw").call_unsafe(server=self, line=line)
@ -271,7 +278,7 @@ class Server(IRCObject.Object):
def read_timed_out(self):
return self.until_read_timeout == 0
def send(self, data):
def send(self, data: str):
returned = self.events.on("preprocess.send").call_unsafe_for_result(
server=self, line=data)
line = returned or data
@ -314,16 +321,16 @@ class Server(IRCObject.Object):
time_left = time_left-now
return time_left
def send_user(self, username, realname):
def send_user(self, username: str, realname: str):
self.send("USER %s 0 * :%s" % (username, realname))
def send_nick(self, nickname):
def send_nick(self, nickname: str):
self.send("NICK %s" % nickname)
def send_capibility_ls(self):
self.send("CAP LS 302")
def queue_capability(self, capability):
def queue_capability(self, capability: str):
self._capability_queue.add(capability)
def queue_capabilities(self, capabilities):
def queue_capabilities(self, capabilities: typing.List[str]):
self._capability_queue.update(capabilities)
def send_capability_queue(self):
if self.has_capability_queue():
@ -332,46 +339,46 @@ class Server(IRCObject.Object):
self.send_capability_request(capabilities)
def has_capability_queue(self):
return bool(len(self._capability_queue))
def send_capability_request(self, capability):
def send_capability_request(self, capability: str):
self.send("CAP REQ :%s" % capability)
def send_capability_end(self):
self.send("CAP END")
def send_authenticate(self, text):
def send_authenticate(self, text: str):
self.send("AUTHENTICATE %s" % text)
def send_starttls(self):
self.send("STARTTLS")
def waiting_for_capabilities(self):
return bool(len(self._capabilities_waiting))
def wait_for_capability(self, capability):
def wait_for_capability(self, capability: str):
self._capabilities_waiting.add(capability)
def capability_done(self, capability):
def capability_done(self, capability: str):
self._capabilities_waiting.remove(capability)
if not self._capabilities_waiting:
self.send_capability_end()
def send_pass(self, password):
def send_pass(self, password: str):
self.send("PASS %s" % password)
def send_ping(self, nonce="hello"):
def send_ping(self, nonce: str="hello"):
self.send("PING :%s" % nonce)
def send_pong(self, nonce="hello"):
def send_pong(self, nonce: str="hello"):
self.send("PONG :%s" % nonce)
def try_rejoin(self, event):
def try_rejoin(self, event: EventManager.Event):
if event["server_id"] == self.id and event["channel_name"
] in self.attempted_join:
self.send_join(event["channel_name"], event["key"])
def send_join(self, channel_name, key=None):
def send_join(self, channel_name: str, key: str=None):
self.send("JOIN %s%s" % (channel_name,
"" if key == None else " %s" % key))
def send_part(self, channel_name, reason=None):
def send_part(self, channel_name: str, reason: str=None):
self.send("PART %s%s" % (channel_name,
"" if reason == None else " %s" % reason))
def send_quit(self, reason="Leaving"):
def send_quit(self, reason: str="Leaving"):
self.send("QUIT :%s" % reason)
def _tag_str(self, tags):
def _tag_str(self, tags: dict):
tag_str = ""
for tag, value in tags.items():
if tag_str:
@ -383,7 +390,8 @@ class Server(IRCObject.Object):
tag_str = "@%s " % tag_str
return tag_str
def send_message(self, target, message, prefix=None, tags={}):
def send_message(self, target: str, message: str, prefix: str=None,
tags: dict={}):
full_message = message if not prefix else prefix+message
self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target,
full_message))
@ -408,7 +416,8 @@ class Server(IRCObject.Object):
message=full_message, message_split=full_message_split,
user=user, action=action, server=self)
def send_notice(self, target, message, prefix=None, tags={}):
def send_notice(self, target: str, message: str, prefix: str=None,
tags: dict={}):
full_message = message if not prefix else prefix+message
self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target,
full_message))
@ -419,31 +428,31 @@ class Server(IRCObject.Object):
self.get_user(target).buffer.add_notice(None, message, tags,
True)
def send_mode(self, target, mode=None, args=None):
def send_mode(self, target: str, mode: str=None, args: str=None):
self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode,
"" if args == None else " %s" % args))
def send_topic(self, channel_name, topic):
def send_topic(self, channel_name: str, topic: str):
self.send("TOPIC %s :%s" % (channel_name, topic))
def send_kick(self, channel_name, target, reason=None):
def send_kick(self, channel_name: str, target: str, reason: str=None):
self.send("KICK %s %s%s" % (channel_name, target,
"" if reason == None else " :%s" % reason))
def send_names(self, channel_name):
def send_names(self, channel_name: str):
self.send("NAMES %s" % channel_name)
def send_list(self, search_for=None):
def send_list(self, search_for: str=None):
self.send(
"LIST%s" % "" if search_for == None else " %s" % search_for)
def send_invite(self, target, channel_name):
def send_invite(self, target: str, channel_name: str):
self.send("INVITE %s %s" % (target, channel_name))
def send_whois(self, target):
def send_whois(self, target: str):
self.send("WHOIS %s" % target)
def send_whowas(self, target, amount=None, server=None):
def send_whowas(self, target: str, amount: int=None, server: str=None):
self.send("WHOWAS %s%s%s" % (target,
"" if amount == None else " %s" % amount,
"" if server == None else " :%s" % server))
def send_who(self, filter=None):
def send_who(self, filter: str=None):
self.send("WHO%s" % ("" if filter == None else " %s" % filter))
def send_whox(self, mask, filter, fields, label=None):
def send_whox(self, mask: str, filter: str, fields: str, label: str=None):
self.send("WHO %s %s%%%s%s" % (mask, filter, fields,
","+label if label else ""))

View file

@ -1,8 +1,9 @@
import uuid
from src import IRCBuffer, IRCObject, utils
import typing, uuid
from src import IRCBot, IRCChannel, IRCBuffer, IRCObject, IRCServer, utils
class User(IRCObject.Object):
def __init__(self, nickname, id, server, bot):
def __init__(self, nickname: str, id: int, server: "IRCServer.Server",
bot: "IRCBot.Bot"):
self.server = server
self.set_nickname(nickname)
self._id = id
@ -20,46 +21,51 @@ class User(IRCObject.Object):
self.away = False
self.buffer = IRCBuffer.Buffer(bot, server)
def __repr__(self):
def __repr__(self) -> str:
return "IRCUser.User(%s|%s)" % (self.server.name, self.name)
def __str__(self):
def __str__(self) -> str:
return self.nickname
def get_id(self):
def get_id(self)-> int:
return (self.identified_account_id_override or
self.identified_account_id or self._id)
def get_identified_account(self):
def get_identified_account(self) -> str:
return (self.identified_account_override or self.identified_account)
def set_nickname(self, nickname):
def set_nickname(self, nickname: str):
self.nickname = nickname
self.nickname_lower = utils.irc.lower(self.server, nickname)
self.nickname_lower = utils.irc.lower(self.server.case_mapping,
nickname)
self.name = self.nickname_lower
def join_channel(self, channel):
def join_channel(self, channel: "IRCChannel.Channel"):
self.channels.add(channel)
def part_channel(self, channel):
def part_channel(self, channel: "IRCChannel.Channel"):
self.channels.remove(channel)
def set_setting(self, setting, value):
def set_setting(self, setting: str, value: typing.Any):
self.bot.database.user_settings.set(self.get_id(), setting, value)
def get_setting(self, setting, default=None):
def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.bot.database.user_settings.get(self.get_id(), setting,
default)
def find_settings(self, pattern, default=[]):
def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find(self.get_id(), pattern,
default)
def find_settings_prefix(self, prefix, default=[]):
def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find_prefix(self.get_id(),
prefix, default)
def del_setting(self, setting):
self.bot.database.user_settings.delete(self.get_id(), setting)
def get_channel_settings_per_setting(self, setting, default=[]):
def get_channel_settings_per_setting(self, setting: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_by_setting(
self.get_id(), setting, default)
def send_message(self, message, prefix=None, tags={}):
def send_message(self, message: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.nickname, message, prefix=prefix,
tags=tags)
def send_notice(self, text, prefix=None, tags={}):
def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags)
def send_ctcp_response(self, command, args):
def send_ctcp_response(self, command: str, args: str):
self.send_notice("\x01%s %s\x01" % (command, args))

View file

@ -1,4 +1,4 @@
import logging, logging.handlers, os, sys, time
import logging, logging.handlers, os, sys, time, typing
LEVELS = {
"trace": logging.DEBUG-1,
@ -23,7 +23,7 @@ class BitBotFormatter(logging.Formatter):
return s
class Log(object):
def __init__(self, level, location):
def __init__(self, level: str, location: str):
logging.addLevelName(LEVELS["trace"], "TRACE")
self.logger = logging.getLogger(__name__)
@ -49,17 +49,17 @@ class Log(object):
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
def trace(self, message, params, **kwargs):
def trace(self, message: str, params: typing.List, **kwargs):
self._log(message, params, LEVELS["trace"], kwargs)
def debug(self, message, params, **kwargs):
def debug(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.DEBUG, kwargs)
def info(self, message, params, **kwargs):
def info(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.INFO, kwargs)
def warn(self, message, params, **kwargs):
def warn(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.WARN, kwargs)
def error(self, message, params, **kwargs):
def error(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.ERROR, kwargs)
def critical(self, message, params, **kwargs):
def critical(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.CRITICAL, kwargs)
def _log(self, message, params, level, kwargs):
def _log(self, message: str, params: typing.List, level: int, kwargs: dict):
self.logger.log(level, message, *params, **kwargs)

View file

@ -1,8 +1,5 @@
import gc, glob, imp, io, inspect, os, sys, uuid
from . import utils
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
import gc, glob, imp, io, inspect, os, sys, typing, uuid
from src import Config, EventManager, Exports, IRCBot, Logging, Timers, utils
class ModuleException(Exception):
pass
@ -22,7 +19,11 @@ class ModuleNotLoadedWarning(ModuleWarning):
pass
class BaseModule(object):
def __init__(self, bot, events, exports, timers):
def __init__(self,
bot: "IRCBot.Bot",
events: EventManager.EventHook,
exports: Exports.Exports,
timers: Timers.Timers):
self.bot = bot
self.events = events
self.exports = exports
@ -32,7 +33,13 @@ class BaseModule(object):
pass
class ModuleManager(object):
def __init__(self, events, exports, timers, config, log, directory):
def __init__(self,
events: EventManager.EventHook,
exports: Exports.Exports,
timers: Timers.Timers,
config: Config.Config,
log: Logging.Log,
directory: str):
self.events = events
self.exports = exports
self.config = config
@ -43,23 +50,24 @@ class ModuleManager(object):
self.modules = {}
self.waiting_requirement = {}
def list_modules(self):
def list_modules(self) -> typing.List[str]:
return sorted(glob.glob(os.path.join(self.directory, "*.py")))
def _module_name(self, path):
def _module_name(self, path: str) -> str:
return os.path.basename(path).rsplit(".py", 1)[0].lower()
def _module_path(self, name):
def _module_path(self, name: str) -> str:
return os.path.join(self.directory, "%s.py" % name)
def _import_name(self, name):
def _import_name(self, name: str) -> str:
return "bitbot_%s" % name
def _get_magic(self, obj, magic, default):
def _get_magic(self, obj: typing.Any, magic: str, default: typing.Any
) -> typing.Any:
return getattr(obj, magic) if hasattr(obj, magic) else default
def _load_module(self, bot, name):
def _load_module(self, bot: "IRCBot.Bot", name: str):
path = self._module_path(name)
for hashflag, value in utils.get_hashflags(path):
for hashflag, value in utils.parse.hashflags(path):
if hashflag == "ignore":
# nope, ignore this module.
raise ModuleNotLoadedWarning("module ignored")
@ -97,10 +105,12 @@ class ModuleManager(object):
module_object._name = name.title()
for attribute_name in dir(module_object):
attribute = getattr(module_object, attribute_name)
for hook in self._get_magic(attribute, BITBOT_HOOKS_MAGIC, []):
for hook in self._get_magic(attribute,
utils.consts.BITBOT_HOOKS_MAGIC, []):
context_events.on(hook["event"]).hook(attribute,
**hook["kwargs"])
for export in self._get_magic(module_object, BITBOT_EXPORTS_MAGIC, []):
for export in self._get_magic(module_object,
utils.consts.BITBOT_EXPORTS_MAGIC, []):
context_exports.add(export["setting"], export["value"])
module_object._context = context
@ -111,7 +121,7 @@ class ModuleManager(object):
"attempted to be used twice")
return module_object
def load_module(self, bot, name):
def load_module(self, bot: "IRCBot.Bot", name: str):
try:
module = self._load_module(bot, name)
except ModuleWarning as warning:
@ -128,7 +138,8 @@ class ModuleManager(object):
self.load_module(bot, requirement_name)
self.log.info("Module '%s' loaded", [name])
def load_modules(self, bot, whitelist=[], blacklist=[]):
def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
blacklist: typing.List[str]=[]):
for path in self.list_modules():
name = self._module_name(path)
if name in whitelist or (not whitelist and not name in blacklist):
@ -137,7 +148,7 @@ class ModuleManager(object):
except ModuleWarning:
pass
def unload_module(self, name):
def unload_module(self, name: str):
if not name in self.modules:
raise ModuleNotFoundException()
module = self.modules[name]

View file

@ -1,7 +1,9 @@
import socket, typing
class Socket(object):
def __init__(self, socket, on_read, encoding="utf8"):
def __init__(self, socket: socket.socket,
on_read: typing.Callable[["Socket", str], None],
encoding: str="utf8"):
self.socket = socket
self._on_read = on_read
self.encoding = encoding
@ -12,18 +14,18 @@ class Socket(object):
self.length = None
self.connected = True
def fileno(self):
def fileno(self) -> int:
return self.socket.fileno()
def disconnect(self):
self.connected = False
def _decode(self, s):
return s.decode(self.encoding) if self.encoding else s
def _encode(self, s):
return s.encode(self.encoding) if self.encoding else s
def _decode(self, s: bytes) -> str:
return s.decode(self.encoding)
def _encode(self, s: str) -> bytes:
return s.encode(self.encoding)
def read(self):
def read(self) -> typing.Optional[typing.List[str]]:
data = self.socket.recv(1024)
if not data:
return None
@ -35,17 +37,17 @@ class Socket(object):
if data_split[-1]:
self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split]
return [data.decode(self.encoding)]
return [self._decode(data)]
def parse_data(self, data):
def parse_data(self, data: str):
self._on_read(self, data)
def send(self, data):
def send(self, data: str):
self._write_buffer += self._encode(data)
def _send(self):
self._write_buffer = self._write_buffer[self.socket.send(
self._write_buffer):]
def waiting_send(self):
def waiting_send(self) -> bool:
return bool(len(self._write_buffer))

View file

@ -1,7 +1,9 @@
import time, uuid
import time, typing, uuid
from src import Database, EventManager, Logging
class Timer(object):
def __init__(self, id, context, name, delay, next_due, kwargs):
def __init__(self, id: int, context: str, name: str, delay: float,
next_due: float, kwargs: dict):
self.id = id
self.context = context
self.name = name
@ -15,9 +17,9 @@ class Timer(object):
def set_next_due(self):
self.next_due = time.time()+self.delay
def due(self):
def due(self) -> bool:
return self.time_left() <= 0
def time_left(self):
def time_left(self) -> float:
return self.next_due-time.time()
def redo(self):
@ -25,42 +27,33 @@ class Timer(object):
self.set_next_due()
def finish(self):
self._done = True
def done(self):
def done(self) -> bool:
return self._done
class TimersContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def add(self, name, delay, next_due=None, **kwargs):
self._parent._add(self.context, name, delay, next_due, None, False,
kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs):
self._parent._add(None, name, delay, next_due, None, True,
kwargs)
class Timers(object):
def __init__(self, database, events, log):
def __init__(self, database: Database.Database,
events: EventManager.EventHook,
log: Logging.Log):
self.database = database
self.events = events
self.log = log
self.timers = []
self.context_timers = {}
def new_context(self, context):
def new_context(self, context: str) -> "TimersContext":
return TimersContext(self, context)
def setup(self, timers):
def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
for name, timer in timers:
id = name.split("timer-", 1)[1]
self._add(timer["name"], None, timer["delay"], timer[
"next-due"], id, False, timer["kwargs"])
def _persist(self, timer):
def _persist(self, timer: Timer):
self.database.bot_settings.set("timer-%s" % timer.id, {
"name": timer.name, "delay": timer.delay,
"next-due": timer.next_due, "kwargs": timer.kwargs})
def _remove(self, timer):
def _remove(self, timer: Timer):
if timer.context:
self.context_timers[timer.context].remove(timer)
if not self.context_timers[timer.context]:
@ -69,11 +62,13 @@ class Timers(object):
self.timers.remove(timer)
self.database.bot_settings.delete("timer-%s" % timer.id)
def add(self, name, delay, next_due=None, **kwargs):
def add(self, name: str, delay: float, next_due: float=None, **kwargs):
self._add(None, name, delay, next_due, None, False, kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs):
def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._add(None, name, delay, next_due, None, True, kwargs)
def _add(self, context, name, delay, next_due, id, persist, kwargs):
def _add(self, context: str, name: str, delay: float, next_due: float,
id: str, persist: bool, kwargs: dict):
id = id or uuid.uuid4().hex
timer = Timer(id, context, name, delay, next_due, kwargs)
if persist:
@ -86,13 +81,13 @@ class Timers(object):
else:
self.timers.append(timer)
def next(self):
def next(self) -> float:
times = filter(None, [timer.time_left() for timer in self.get_timers()])
if not times:
return None
return max(min(times), 0)
def get_timers(self):
def get_timers(self) -> typing.List[Timer]:
return self.timers + sum(self.context_timers.values(), [])
def call(self):
@ -104,6 +99,19 @@ class Timers(object):
if timer.done():
self._remove(timer)
def purge_context(self, context):
def purge_context(self, context: str):
if context in self.context_timers:
del self.context_timers[context]
class TimersContext(object):
def __init__(self, parent: Timers, context: str):
self._parent = parent
self.context = context
def add(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._parent._add(self.context, name, delay, next_due, None, False,
kwargs)
def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._parent._add(None, name, delay, next_due, None, True,
kwargs)

View file

@ -1,6 +1,5 @@
import decimal, io, re
from src import ModuleManager
from . import irc, http
import decimal, io, re, typing
from src.utils import consts, irc, http, parse
TIME_SECOND = 1
TIME_MINUTE = TIME_SECOND*60
@ -8,7 +7,7 @@ TIME_HOUR = TIME_MINUTE*60
TIME_DAY = TIME_HOUR*24
TIME_WEEK = TIME_DAY*7
def time_unit(seconds):
def time_unit(seconds: int) -> typing.Tuple[int, str]:
since = None
unit = None
if seconds >= TIME_WEEK:
@ -29,7 +28,7 @@ def time_unit(seconds):
since = int(since)
if since > 1:
unit = "%ss" % unit # pluralise the unit
return [since, unit]
return (since, unit)
REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I)
@ -38,7 +37,7 @@ SECONDS_HOURS = SECONDS_MINUTES*60
SECONDS_DAYS = SECONDS_HOURS*24
SECONDS_WEEKS = SECONDS_DAYS*7
def from_pretty_time(pretty_time):
def from_pretty_time(pretty_time: str) -> typing.Optional[int]:
seconds = 0
for match in re.findall(REGEX_PRETTYTIME, pretty_time):
number, unit = int(match[:-1]), match[-1].lower()
@ -54,12 +53,14 @@ def from_pretty_time(pretty_time):
if seconds > 0:
return seconds
UNIT_MINIMUM = 6
UNIT_SECOND = 5
UNIT_MINUTE = 4
UNIT_HOUR = 3
UNIT_DAY = 2
UNIT_WEEK = 1
def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
def to_pretty_time(total_seconds: int, minimum_unit: int=UNIT_SECOND,
max_units: int=UNIT_MINIMUM) -> str:
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24)
@ -84,7 +85,7 @@ def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
units += 1
return out
def parse_number(s):
def parse_number(s: str) -> str:
try:
decimal.Decimal(s)
return s
@ -110,28 +111,18 @@ def parse_number(s):
IS_TRUE = ["true", "yes", "on", "y"]
IS_FALSE = ["false", "no", "off", "n"]
def bool_or_none(s):
def bool_or_none(s: str) -> typing.Optional[bool]:
s = s.lower()
if s in IS_TRUE:
return True
elif s in IS_FALSE:
return False
def int_or_none(s):
def int_or_none(s: str) -> typing.Optional[int]:
stripped_s = s.lstrip("0")
if stripped_s.isdigit():
return int(stripped_s)
def get_closest_setting(event, setting, default=None):
server = event["server"]
if "channel" in event:
closest = event["channel"]
elif "target" in event and "is_channel" in event and event["is_channel"]:
closest = event["target"]
else:
closest = event["user"]
return closest.get_setting(setting, server.get_setting(setting, default))
def prevent_highlight(nickname):
def prevent_highlight(nickname: str) -> str:
return nickname[0]+"\u200c"+nickname[1:]
class EventError(Exception):
@ -139,86 +130,40 @@ class EventError(Exception):
class EventsResultsError(EventError):
def __init__(self):
EventError.__init__(self, "Failed to load results")
class EventsNotEnoughArgsError(EventError):
def __init__(self, n):
EventError.__init__(self, "Not enough arguments (minimum %d)" % n)
class EventsUsageError(EventError):
def __init__(self, usage):
EventError.__init__(self, "Not enough arguments, usage: %s" % usage)
def _set_get_append(obj, setting, item):
def _set_get_append(obj: typing.Any, setting: str, item: typing.Any):
if not hasattr(obj, setting):
setattr(obj, setting, [])
getattr(obj, setting).append(item)
def hook(event, **kwargs):
def hook(event: str, **kwargs):
def _hook_func(func):
_set_get_append(func, ModuleManager.BITBOT_HOOKS_MAGIC,
_set_get_append(func, consts.BITBOT_HOOKS_MAGIC,
{"event": event, "kwargs": kwargs})
return func
return _hook_func
def export(setting, value):
def export(setting: str, value: typing.Any):
def _export_func(module):
_set_get_append(module, ModuleManager.BITBOT_EXPORTS_MAGIC,
_set_get_append(module, consts.BITBOT_EXPORTS_MAGIC,
{"setting": setting, "value": value})
return module
return _export_func
COMMENT_TYPES = ["#", "//"]
def get_hashflags(filename):
hashflags = {}
with io.open(filename, mode="r", encoding="utf8") as f:
for line in f:
line = line.strip("\n")
found = False
for comment_type in COMMENT_TYPES:
if line.startswith(comment_type):
line = line.replace(comment_type, "", 1).lstrip()
found = True
break
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
def top_10(items: typing.List[typing.Any],
convert_key: TOP_10_CALLABLE=lambda x: x,
value_format: TOP_10_CALLABLE=lambda x: x):
top_10 = sorted(items.keys())
top_10 = sorted(top_10, key=items.get, reverse=True)[:10]
if not found:
break
elif line.startswith("--"):
hashflag, sep, value = line[2:].partition(" ")
hashflags[hashflag] = value if sep else None
return hashflags.items()
top_10_items = []
for key in top_10:
top_10_items.append("%s (%s)" % (convert_key(key),
value_format(items[key])))
class Docstring(object):
def __init__(self, description, items, var_items):
self.description = description
self.items = items
self.var_items = var_items
def parse_docstring(s):
description = ""
last_item = None
items = {}
var_items = {}
if s:
for line in s.split("\n"):
line = line.strip()
if line:
if line[0] == ":":
key, _, value = line[1:].partition(": ")
last_item = key
if key in var_items:
var_items[key].append(value)
elif key in items:
var_items[key] = [items.pop(key), value]
else:
items[key] = value
else:
if last_item:
items[last_item] += " %s" % line
else:
if description:
description += " "
description += line
return Docstring(description, items, var_items)
def top_10(items, convert_key=lambda x: x, value_format=lambda x: x):
top_10 = sorted(items.keys())
top_10 = sorted(top_10, key=items.get, reverse=True)[:10]
top_10_items = []
for key in top_10:
top_10_items.append("%s (%s)" % (convert_key(key),
value_format(items[key])))
return top_10_items
return top_10_items

2
src/utils/consts.py Normal file
View file

@ -0,0 +1,2 @@
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"

View file

@ -1,4 +1,4 @@
import re, signal, traceback, urllib.error, urllib.parse
import re, signal, traceback, typing, urllib.error, urllib.parse
import json as _json
import bs4, requests
@ -18,9 +18,10 @@ class HTTPParsingException(HTTPException):
def throw_timeout():
raise HTTPTimeoutException()
def get_url(url, method="GET", get_params={}, post_data=None, headers={},
json_data=None, code=False, json=False, soup=False, parser="lxml",
fallback_encoding="utf8"):
def get_url(url: str, method: str="GET", get_params: dict={},
post_data: typing.Any=None, headers: dict={},
json_data: typing.Any=None, code: bool=False, json: bool=False,
soup: bool=False, parser: str="lxml", fallback_encoding: str="utf8"):
if not urllib.parse.urlparse(url).scheme:
url = "http://%s" % url
@ -66,6 +67,6 @@ def get_url(url, method="GET", get_params={}, post_data=None, headers={},
else:
return data
def strip_html(s):
def strip_html(s: str) -> str:
return bs4.BeautifulSoup(s, "lxml").get_text()

View file

@ -1,4 +1,4 @@
import string, re
import string, re, typing
ASCII_UPPER = string.ascii_uppercase
ASCII_LOWER = string.ascii_lowercase
@ -7,32 +7,36 @@ STRICT_RFC1459_LOWER = ASCII_LOWER+r'|{}'
RFC1459_UPPER = STRICT_RFC1459_UPPER+"^"
RFC1459_LOWER = STRICT_RFC1459_LOWER+"~"
def remove_colon(s):
def remove_colon(s: str) -> str:
if s.startswith(":"):
s = s[1:]
return s
MULTI_REPLACE_ITERABLE = typing.Iterable[str]
# case mapping lowercase/uppcase logic
def _multi_replace(s, chars1, chars2):
def _multi_replace(s: str,
chars1: typing.Iterable[str],
chars2: typing.Iterable[str]) -> str:
for char1, char2 in zip(chars1, chars2):
s = s.replace(char1, char2)
return s
def lower(server, s):
if server.case_mapping == "ascii":
def lower(case_mapping: str, s: str) -> str:
if case_mapping == "ascii":
return _multi_replace(s, ASCII_UPPER, ASCII_LOWER)
elif server.case_mapping == "rfc1459":
elif case_mapping == "rfc1459":
return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER)
elif server.case_mapping == "strict-rfc1459":
elif case_mapping == "strict-rfc1459":
return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER)
else:
raise ValueError("unknown casemapping '%s'" % server.case_mapping)
raise ValueError("unknown casemapping '%s'" % case_mapping)
# compare a string while respecting case mapping
def equals(server, s1, s2):
return lower(server, s1) == lower(server, s2)
def equals(case_mapping: str, s1: str, s2: str) -> bool:
return lower(case_mapping, s1) == lower(case_mapping, s2)
class IRCHostmask(object):
def __init__(self, nickname, username, hostname, hostmask):
def __init__(self, nickname: str, username: str, hostname: str,
hostmask: str):
self.nickname = nickname
self.username = username
self.hostname = hostname
@ -42,24 +46,24 @@ class IRCHostmask(object):
def __str__(self):
return self.hostmask
def seperate_hostmask(hostmask):
def seperate_hostmask(hostmask: str) -> IRCHostmask:
hostmask = remove_colon(hostmask)
nickname, _, username = hostmask.partition("!")
username, _, hostname = username.partition("@")
return IRCHostmask(nickname, username, hostname, hostmask)
class IRCLine(object):
def __init__(self, tags, prefix, command, args, arbitrary, last, server):
def __init__(self, tags: dict, prefix: str, command: str,
args: typing.List[str], arbitrary: typing.Optional[str],
last: str):
self.tags = tags
self.prefix = prefix
self.command = command
self.args = args
self.arbitrary = arbitrary
self.last = last
self.server = server
def parse_line(server, line):
def parse_line(line: str) -> IRCLine:
tags = {}
prefix = None
command = None
@ -81,7 +85,7 @@ def parse_line(server, line):
args = line.split(" ")
last = arbitrary or args[-1]
return IRCLine(tags, prefix, command, args, arbitrary, last, server)
return IRCLine(tags, prefix, command, args, arbitrary, last)
COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3
COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7
@ -94,20 +98,20 @@ FONT_BOLD, FONT_ITALIC, FONT_UNDERLINE, FONT_INVERT = ("\x02", "\x1D",
FONT_COLOR, FONT_RESET = "\x03", "\x0F"
REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR)
def color(s, foreground, background=None):
def color(s: str, foreground: str, background: str=None) -> str:
foreground = str(foreground).zfill(2)
if background:
background = str(background).zfill(2)
return "%s%s%s%s%s" % (FONT_COLOR, foreground,
"" if not background else ",%s" % background, s, FONT_COLOR)
def bold(s):
def bold(s: str) -> str:
return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD)
def underline(s):
def underline(s: str) -> str:
return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE)
def strip_font(s):
def strip_font(s: str) -> str:
s = s.replace(FONT_BOLD, "")
s = s.replace(FONT_ITALIC, "")
s = REGEX_COLOR.sub("", s)

57
src/utils/parse.py Normal file
View file

@ -0,0 +1,57 @@
import io, typing
COMMENT_TYPES = ["#", "//"]
def hashflags(filename: str) -> typing.List[typing.Tuple[str, str]]:
hashflags = {}
with io.open(filename, mode="r", encoding="utf8") as f:
for line in f:
line = line.strip("\n")
found = False
for comment_type in COMMENT_TYPES:
if line.startswith(comment_type):
line = line.replace(comment_type, "", 1).lstrip()
found = True
break
if not found:
break
elif line.startswith("--"):
hashflag, sep, value = line[2:].partition(" ")
hashflags[hashflag] = value if sep else None
return list(hashflags.items())
class Docstring(object):
def __init__(self, description: str, items: dict, var_items: dict):
self.description = description
self.items = items
self.var_items = var_items
def docstring(s: str) -> Docstring:
description = ""
last_item = None
items = {}
var_items = {}
if s:
for line in s.split("\n"):
line = line.strip()
if line:
if line[0] == ":":
key, _, value = line[1:].partition(": ")
last_item = key
if key in var_items:
var_items[key].append(value)
elif key in items:
var_items[key] = [items.pop(key), value]
else:
items[key] = value
else:
if last_item:
items[last_item] += " %s" % line
else:
if description:
description += " "
description += line
return Docstring(description, items, var_items)