Add type/return hints throughout src/ and, in doing so, fix some cyclical

references.
This commit is contained in:
jesopo 2018-10-30 14:58:48 +00:00
parent 705daaf9bb
commit e07553c362
22 changed files with 605 additions and 516 deletions

View file

@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule):
if not event["user"].last_karma or (time.time()-event["user" if not event["user"].last_karma or (time.time()-event["user"
].last_karma) >= KARMA_DELAY_SECONDS: ].last_karma) >= KARMA_DELAY_SECONDS:
target = match.group(1).strip() target = match.group(1).strip()
if utils.irc.lower(event["server"], target if utils.irc.lower(event["server"].case_mapping, target
) == event["user"].name: ) == event["user"].name:
if verbose: if verbose:
self.events.on("send.stderr").call( self.events.on("send.stderr").call(

View file

@ -542,7 +542,8 @@ class Module(ModuleManager.BaseModule):
# we need a registered nickname for this channel # we need a registered nickname for this channel
@utils.hook("raw.477", default_event=True) @utils.hook("raw.477", default_event=True)
def handle_477(self, event): def handle_477(self, event):
channel_name = utils.irc.lower(event["server"], event["args"][1]) channel_name = utils.irc.lower(event["server"].case_mapping,
event["args"][1])
if channel_name in event["server"]: if channel_name in event["server"]:
key = event["server"].attempted_join[channel_name] key = event["server"].attempted_join[channel_name]
self.timers.add("rejoin", 5, channel_name=channe_name, key=key, self.timers.add("rejoin", 5, channel_name=channe_name, key=key,

View file

@ -11,12 +11,16 @@ REGEX_SED = re.compile("^s/")
"help": "Disable/Enable sed only looking at the messages sent by the user", "help": "Disable/Enable sed only looking at the messages sent by the user",
"validate": utils.bool_or_none}) "validate": utils.bool_or_none})
class Module(ModuleManager.BaseModule): class Module(ModuleManager.BaseModule):
def _closest_setting(self, event, setting, default):
return event["channel"].get_setting(setting,
event["server"].get_setting(setting, default))
@utils.hook("received.message.channel") @utils.hook("received.message.channel")
def channel_message(self, event): def channel_message(self, event):
sed_split = re.split(REGEX_SPLIT, event["message"], 3) sed_split = re.split(REGEX_SPLIT, event["message"], 3)
if event["message"].startswith("s/") and len(sed_split) > 2: if event["message"].startswith("s/") and len(sed_split) > 2:
if event["action"] or not utils.get_closest_setting( if event["action"] or not self._closest_setting(event, "sed",
event, "sed", False): False):
return return
regex_flags = 0 regex_flags = 0
@ -48,9 +52,8 @@ class Module(ModuleManager.BaseModule):
return return
replace = sed_split[2].replace("\\/", "/") replace = sed_split[2].replace("\\/", "/")
for_user = event["user"].nickname if utils.get_closest_setting( for_user = event["user"].nickname if self._closest_setting(event,
event, "sed-sender-only", False "sed-sender-only", False) else None
) else None
line = event["channel"].buffer.find(pattern, from_self=False, line = event["channel"].buffer.find(pattern, from_self=False,
for_user=for_user, not_pattern=REGEX_SED) for_user=for_user, not_pattern=REGEX_SED)
if line: if line:

View file

@ -1,21 +1,21 @@
import time, uuid import time, typing, uuid
class Cache(object): class Cache(object):
def __init__(self): def __init__(self):
self._items = {} self._items = {}
self._item_to_id = {} self._item_to_id = {}
def cache(self, item): def cache(self, item: typing.Any) -> str:
return self._cache(item, None) return self._cache(item, None)
def temporary_cache(self, item, timeout): def temporary_cache(self, item: typing.Any, timeout: float)-> str:
return self._cache(item, timeout) return self._cache(item, timeout)
def _cache(self, item, timeout): def _cache(self, item: typing.Any, timeout: float) -> str:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
self._items[id] = [item, time.monotonic()+timeout] self._items[id] = [item, time.monotonic()+timeout]
self._item_to_id[item] = id self._item_to_id[item] = id
return id return id
def next_expiration(self): def next_expiration(self) -> float:
expirations = [self._items[id][1] for id in self._items] expirations = [self._items[id][1] for id in self._items]
expirations = list(filter(None, expirations)) expirations = list(filter(None, expirations))
if not expirations: if not expirations:
@ -35,17 +35,17 @@ class Cache(object):
del self._items[id] del self._items[id]
del self._item_to_id[item] del self._item_to_id[item]
def has_item(self, item): def has_item(self, item: typing.Any) -> bool:
return item in self._item_to_id return item in self._item_to_id
def get(self, id): def get(self, id: str) -> typing.Any:
item, expiration = self._items[id] item, expiration = self._items[id]
return item return item
def get_expiration(self, item): def get_expiration(self, item: typing.Any) -> float:
id = self._item_to_id[item] id = self._item_to_id[item]
item, expiration = self._items[id] item, expiration = self._items[id]
return expiration return expiration
def until_expiration(self, item): def until_expiration(self, item: typing.Any) -> float:
expiration = self.get_expiration(item) expiration = self.get_expiration(item)
return expiration-time.monotonic() return expiration-time.monotonic()

View file

@ -1,7 +1,7 @@
import configparser, os import configparser, os, typing
class Config(object): class Config(object):
def __init__(self, location): def __init__(self, location: str):
self.location = location self.location = location
self._config = {} self._config = {}
self.load() self.load()
@ -13,10 +13,10 @@ class Config(object):
parser.read_string(config_file.read()) parser.read_string(config_file.read())
self._config = dict(parser["bot"].items()) self._config = dict(parser["bot"].items())
def __getitem__(self, key): def __getitem__(self, key: str) -> typing.Any:
return self._config[key] return self._config[key]
def get(self, key, default=None): def get(self, key: str, default: typing.Any=None) -> typing.Any:
return self._config.get(key, default) return self._config.get(key, default)
def __contains__(self, key): def __contains__(self, key: str) -> bool:
return key in self._config return key in self._config

View file

@ -1,12 +1,14 @@
import json, os, sqlite3, threading, time import json, os, sqlite3, threading, time, typing
from src import Logging
class Table(object): class Table(object):
def __init__(self, database): def __init__(self, database):
self.database = database self.database = database
class Servers(Table): class Servers(Table):
def add(self, alias, hostname, port, password, ipv4, tls, bindhost, def add(self, alias: str, hostname: str, port: int, password: str,
nickname, username=None, realname=None): ipv4: bool, tls: bool, bindhost: str,
nickname: str, username: str=None, realname: str=None):
username = username or nickname username = username or nickname
realname = realname or nickname realname = realname or nickname
self.database.execute( self.database.execute(
@ -18,7 +20,7 @@ class Servers(Table):
def get_all(self): def get_all(self):
return self.database.execute_fetchall( return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers") "SELECT server_id, alias FROM servers")
def get(self, id): def get(self, id: int):
return self.database.execute_fetchone( return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, ipv4, """SELECT server_id, alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname FROM servers WHERE tls, bindhost, nickname, username, realname FROM servers WHERE
@ -26,46 +28,46 @@ class Servers(Table):
[id]) [id])
class Channels(Table): class Channels(Table):
def add(self, server_id, name): def add(self, server_id: int, name: str):
self.database.execute("""INSERT OR IGNORE INTO channels self.database.execute("""INSERT OR IGNORE INTO channels
(server_id, name) VALUES (?, ?)""", (server_id, name) VALUES (?, ?)""",
[server_id, name.lower()]) [server_id, name.lower()])
def delete(self, channel_id): def delete(self, channel_id: int):
self.database.execute("DELETE FROM channels WHERE channel_id=?", self.database.execute("DELETE FROM channels WHERE channel_id=?",
[channel_id]) [channel_id])
def get_id(self, server_id, name): def get_id(self, server_id: int, name: str):
value = self.database.execute_fetchone("""SELECT channel_id FROM value = self.database.execute_fetchone("""SELECT channel_id FROM
channels WHERE server_id=? AND name=?""", channels WHERE server_id=? AND name=?""",
[server_id, name.lower()]) [server_id, name.lower()])
return value if value == None else value[0] return value if value == None else value[0]
class Users(Table): class Users(Table):
def add(self, server_id, nickname): def add(self, server_id: int, nickname: str):
self.database.execute("""INSERT OR IGNORE INTO users self.database.execute("""INSERT OR IGNORE INTO users
(server_id, nickname) VALUES (?, ?)""", (server_id, nickname) VALUES (?, ?)""",
[server_id, nickname.lower()]) [server_id, nickname.lower()])
def delete(self, user_id): def delete(self, user_id: int):
self.database.execute("DELETE FROM users WHERE user_id=?", self.database.execute("DELETE FROM users WHERE user_id=?",
[user_id]) [user_id])
def get_id(self, server_id, nickname): def get_id(self, server_id: int, nickname: str):
value = self.database.execute_fetchone("""SELECT user_id FROM value = self.database.execute_fetchone("""SELECT user_id FROM
users WHERE server_id=? and nickname=?""", users WHERE server_id=? and nickname=?""",
[server_id, nickname.lower()]) [server_id, nickname.lower()])
return value if value == None else value[0] return value if value == None else value[0]
class BotSettings(Table): class BotSettings(Table):
def set(self, setting, value): def set(self, setting: str, value: typing.Any):
self.database.execute( self.database.execute(
"INSERT OR REPLACE INTO bot_settings VALUES (?, ?)", "INSERT OR REPLACE INTO bot_settings VALUES (?, ?)",
[setting.lower(), json.dumps(value)]) [setting.lower(), json.dumps(value)])
def get(self, setting, default=None): def get(self, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone( value = self.database.execute_fetchone(
"SELECT value FROM bot_settings WHERE setting=?", "SELECT value FROM bot_settings WHERE setting=?",
[setting.lower()]) [setting.lower()])
if value: if value:
return json.loads(value[0]) return json.loads(value[0])
return default return default
def find(self, pattern, default=[]): def find(self, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"SELECT setting, value FROM bot_settings WHERE setting LIKE ?", "SELECT setting, value FROM bot_settings WHERE setting LIKE ?",
[pattern.lower()]) [pattern.lower()])
@ -74,19 +76,19 @@ class BotSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_prefix(self, prefix, default=[]): def find_prefix(self, prefix: str, default: typing.Any=[]):
return self.find("%s%%" % prefix, default) return self.find("%s%%" % prefix, default)
def delete(self, setting): def delete(self, setting: str):
self.database.execute( self.database.execute(
"DELETE FROM bot_settings WHERE setting=?", "DELETE FROM bot_settings WHERE setting=?",
[setting.lower()]) [setting.lower()])
class ServerSettings(Table): class ServerSettings(Table):
def set(self, server_id, setting, value): def set(self, server_id: int, setting: str, value: typing.Any):
self.database.execute( self.database.execute(
"INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)", "INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)",
[server_id, setting.lower(), json.dumps(value)]) [server_id, setting.lower(), json.dumps(value)])
def get(self, server_id, setting, default=None): def get(self, server_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone( value = self.database.execute_fetchone(
"""SELECT value FROM server_settings WHERE """SELECT value FROM server_settings WHERE
server_id=? AND setting=?""", server_id=? AND setting=?""",
@ -94,7 +96,7 @@ class ServerSettings(Table):
if value: if value:
return json.loads(value[0]) return json.loads(value[0])
return default return default
def find(self, server_id, pattern, default=[]): def find(self, server_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT setting, value FROM server_settings WHERE """SELECT setting, value FROM server_settings WHERE
server_id=? AND setting LIKE ?""", server_id=? AND setting LIKE ?""",
@ -104,26 +106,26 @@ class ServerSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_prefix(self, server_id, prefix, default=[]): def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]):
return self.find_server_settings(server_id, "%s%%" % prefix, default) return self.find_server_settings(server_id, "%s%%" % prefix, default)
def delete(self, server_id, setting): def delete(self, server_id: int, setting: str):
self.database.execute( self.database.execute(
"DELETE FROM server_settings WHERE server_id=? AND setting=?", "DELETE FROM server_settings WHERE server_id=? AND setting=?",
[server_id, setting.lower()]) [server_id, setting.lower()])
class ChannelSettings(Table): class ChannelSettings(Table):
def set(self, channel_id, setting, value): def set(self, channel_id: int, setting: str, value: typing.Any):
self.database.execute( self.database.execute(
"INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)", "INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)",
[channel_id, setting.lower(), json.dumps(value)]) [channel_id, setting.lower(), json.dumps(value)])
def get(self, channel_id, setting, default=None): def get(self, channel_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone( value = self.database.execute_fetchone(
"""SELECT value FROM channel_settings WHERE """SELECT value FROM channel_settings WHERE
channel_id=? AND setting=?""", [channel_id, setting.lower()]) channel_id=? AND setting=?""", [channel_id, setting.lower()])
if value: if value:
return json.loads(value[0]) return json.loads(value[0])
return default return default
def find(self, channel_id, pattern, default=[]): def find(self, channel_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT setting, value FROM channel_settings WHERE """SELECT setting, value FROM channel_settings WHERE
channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()]) channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()])
@ -132,15 +134,15 @@ class ChannelSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_prefix(self, channel_id, prefix, default=[]): def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]):
return self.find_channel_settings(channel_id, "%s%%" % prefix, return self.find_channel_settings(channel_id, "%s%%" % prefix,
default) default)
def delete(self, channel_id, setting): def delete(self, channel_id: int, setting: str):
self.database.execute( self.database.execute(
"""DELETE FROM channel_settings WHERE channel_id=? """DELETE FROM channel_settings WHERE channel_id=?
AND setting=?""", [channel_id, setting.lower()]) AND setting=?""", [channel_id, setting.lower()])
def find_by_setting(self, setting, default=[]): def find_by_setting(self, setting: str, default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT channels.server_id, channels.name, """SELECT channels.server_id, channels.name,
channel_settings.value FROM channel_settings channel_settings.value FROM channel_settings
@ -154,18 +156,19 @@ class ChannelSettings(Table):
return default return default
class UserSettings(Table): class UserSettings(Table):
def set(self, user_id, setting, value): def set(self, user_id: int, setting: str, value: typing.Any):
self.database.execute( self.database.execute(
"INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)", "INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)",
[user_id, setting.lower(), json.dumps(value)]) [user_id, setting.lower(), json.dumps(value)])
def get(self, user_id, setting, default=None): def get(self, user_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone( value = self.database.execute_fetchone(
"""SELECT value FROM user_settings WHERE """SELECT value FROM user_settings WHERE
user_id=? and setting=?""", [user_id, setting.lower()]) user_id=? and setting=?""", [user_id, setting.lower()])
if value: if value:
return json.loads(value[0]) return json.loads(value[0])
return default return default
def find_all_by_setting(self, server_id, setting, default=[]): def find_all_by_setting(self, server_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT users.nickname, user_settings.value FROM """SELECT users.nickname, user_settings.value FROM
user_settings INNER JOIN users ON user_settings INNER JOIN users ON
@ -177,7 +180,7 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find(self, user_id, pattern, default=[]): def find(self, user_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute( values = self.database.execute(
"""SELECT setting, value FROM user_settings WHERE """SELECT setting, value FROM user_settings WHERE
user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()]) user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()])
@ -186,20 +189,22 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_prefix(self, user_id, prefix, default=[]): def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]):
return self.find_user_settings(user_id, "%s%%" % prefix, default) return self.find_user_settings(user_id, "%s%%" % prefix, default)
def delete(self, user_id, setting): def delete(self, user_id: int, setting: str):
self.database.execute( self.database.execute(
"""DELETE FROM user_settings WHERE """DELETE FROM user_settings WHERE
user_id=? AND setting=?""", [user_id, setting.lower()]) user_id=? AND setting=?""", [user_id, setting.lower()])
class UserChannelSettings(Table): class UserChannelSettings(Table):
def set(self, user_id, channel_id, setting, value): def set(self, user_id: int, channel_id: int, setting: str,
value: typing.Any):
self.database.execute( self.database.execute(
"""INSERT OR REPLACE INTO user_channel_settings VALUES """INSERT OR REPLACE INTO user_channel_settings VALUES
(?, ?, ?, ?)""", (?, ?, ?, ?)""",
[user_id, channel_id, setting.lower(), json.dumps(value)]) [user_id, channel_id, setting.lower(), json.dumps(value)])
def get(self, user_id, channel_id, setting, default=None): def get(self, user_id: int, channel_id: int, setting: str,
default: typing.Any=None):
value = self.database.execute_fetchone( value = self.database.execute_fetchone(
"""SELECT value FROM user_channel_settings WHERE """SELECT value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""", user_id=? AND channel_id=? AND setting=?""",
@ -207,7 +212,8 @@ class UserChannelSettings(Table):
if value: if value:
return json.loads(value[0]) return json.loads(value[0])
return default return default
def find(self, user_id, channel_id, pattern, default=[]): def find(self, user_id: int, channel_id: int, pattern: str,
default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT setting, value FROM user_channel_settings WHERE """SELECT setting, value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting LIKE '?'""", user_id=? AND channel_id=? AND setting LIKE '?'""",
@ -217,10 +223,12 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_prefix(self, user_id, channel_id, prefix, default=[]): def find_prefix(self, user_id: int, channel_id: int, prefix: str,
default: typing.Any=[]):
return self.find_user_settings(user_id, channel_id, "%s%%" % prefix, return self.find_user_settings(user_id, channel_id, "%s%%" % prefix,
default) default)
def find_by_setting(self, user_id, setting, default=[]): def find_by_setting(self, user_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT channels.name, user_channel_settings.value FROM """SELECT channels.name, user_channel_settings.value FROM
user_channel_settings INNER JOIN channels ON user_channel_settings INNER JOIN channels ON
@ -232,7 +240,8 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1]) values[i] = value[0], json.loads(value[1])
return values return values
return default return default
def find_all_by_setting(self, server_id, setting, default=[]): def find_all_by_setting(self, server_id: int, setting: str,
default: typing.Any=[]):
values = self.database.execute_fetchall( values = self.database.execute_fetchall(
"""SELECT channels.name, users.nickname, """SELECT channels.name, users.nickname,
user_channel_settings.value FROM user_channel_settings.value FROM
@ -246,14 +255,14 @@ class UserChannelSettings(Table):
values[i] = value[0], value[1], json.loads(value[2]) values[i] = value[0], value[1], json.loads(value[2])
return values return values
return default return default
def delete(self, user_id, channel_id, setting): def delete(self, user_id: int, channel_id: int, setting: str):
self.database.execute( self.database.execute(
"""DELETE FROM user_channel_settings WHERE """DELETE FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""", user_id=? AND channel_id=? AND setting=?""",
[user_id, channel_id, setting.lower()]) [user_id, channel_id, setting.lower()])
class Database(object): class Database(object):
def __init__(self, log, location): def __init__(self, log: "Logging.Log", location: str):
self.log = log self.log = log
self.location = location self.location = location
self.database = sqlite3.connect(self.location, self.database = sqlite3.connect(self.location,
@ -284,7 +293,9 @@ class Database(object):
self._cursor = self.database.cursor() self._cursor = self.database.cursor()
return self._cursor return self._cursor
def _execute_fetch(self, query, fetch_func, params=[]): def _execute_fetch(self, query: str,
fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
params: typing.List=[]):
printable_query = " ".join(query.split()) printable_query = " ".join(query.split())
self.log.trace("executing query: \"%s\" (params: %s)", self.log.trace("executing query: \"%s\" (params: %s)",
[printable_query, params]) [printable_query, params])
@ -299,16 +310,16 @@ class Database(object):
self.log.trace("executed in %fms", [total_milliseconds]) self.log.trace("executed in %fms", [total_milliseconds])
return value return value
def execute_fetchall(self, query, params=[]): def execute_fetchall(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, return self._execute_fetch(query,
lambda cursor: cursor.fetchall(), params) lambda cursor: cursor.fetchall(), params)
def execute_fetchone(self, query, params=[]): def execute_fetchone(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, return self._execute_fetch(query,
lambda cursor: cursor.fetchone(), params) lambda cursor: cursor.fetchone(), params)
def execute(self, query, params=[]): def execute(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, lambda cursor: None, params) return self._execute_fetch(query, lambda cursor: None, params)
def has_table(self, table_name): def has_table(self, table_name: str):
result = self.execute_fetchone("""SELECT COUNT(*) FROM result = self.execute_fetchone("""SELECT COUNT(*) FROM
sqlite_master WHERE type='table' AND name=?""", sqlite_master WHERE type='table' AND name=?""",
[table_name]) [table_name])

View file

@ -1,5 +1,5 @@
import itertools, time, traceback import itertools, time, traceback, typing
from src import utils from src import Logging, utils
PRIORITY_URGENT = 0 PRIORITY_URGENT = 0
PRIORITY_HIGH = 1 PRIORITY_HIGH = 1
@ -11,94 +11,39 @@ DEFAULT_PRIORITY = PRIORITY_MEDIUM
DEFAULT_EVENT_DELIMITER = "." DEFAULT_EVENT_DELIMITER = "."
DEFAULT_MULTI_DELIMITER = "|" DEFAULT_MULTI_DELIMITER = "|"
CALLBACK_TYPE = typing.Callable[["Event"], typing.Any]
class Event(object): class Event(object):
def __init__(self, name, **kwargs): def __init__(self, name: str, **kwargs):
self.name = name self.name = name
self.kwargs = kwargs self.kwargs = kwargs
self.eaten = False self.eaten = False
def __getitem__(self, key): def __getitem__(self, key: str) -> typing.Any:
return self.kwargs[key] return self.kwargs[key]
def get(self, key, default=None): def get(self, key: str, default=None) -> typing.Any:
return self.kwargs.get(key, default) return self.kwargs.get(key, default)
def __contains__(self, key): def __contains__(self, key: str) -> bool:
return key in self.kwargs return key in self.kwargs
def eat(self): def eat(self):
self.eaten = True self.eaten = True
class EventCallback(object): class EventCallback(object):
def __init__(self, function, priority, kwargs): def __init__(self, function: CALLBACK_TYPE, priority: int, kwargs: dict):
self.function = function self.function = function
self.priority = priority self.priority = priority
self.kwargs = kwargs self.kwargs = kwargs
self.docstring = utils.parse_docstring(function.__doc__) self.docstring = utils.parse.docstring(function.__doc__)
def call(self, event): def call(self, event: Event) -> typing.Any:
return self.function(event) return self.function(event)
def get_kwarg(self, name, default=None): def get_kwarg(self, name: str, default=None) -> typing.Any:
item = self.kwargs.get(name, default) item = self.kwargs.get(name, default)
return item or self.docstring.items.get(name, default) return item or self.docstring.items.get(name, default)
class MultipleEventHook(object):
def __init__(self):
self._event_hooks = set([])
def _add(self, event_hook):
self._event_hooks.add(event_hook)
def hook(self, function, **kwargs):
for event_hook in self._event_hooks:
event_hook.hook(function, **kwargs)
def call_limited(self, maximum, **kwargs):
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call_limited(maximum, **kwargs))
return returns
def call(self, **kwargs):
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call(**kwargs))
return returns
class EventHookContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
**kwargs):
return self._parent._context_hook(self.context, function, priority,
replay, kwargs)
def unhook(self, callback):
self._parent.unhook(callback)
def on(self, subevent, *extra_subevents,
delimiter=DEFAULT_EVENT_DELIMITER):
return self._parent._context_on(self.context, subevent,
extra_subevents, delimiter)
def call_for_result(self, default=None, **kwargs):
return self._parent.call_for_result(default, **kwargs)
def assure_call(self, **kwargs):
self._parent.assure_call(**kwargs)
def call(self, **kwargs):
return self._parent.call(**kwargs)
def call_limited(self, maximum, **kwargs):
return self._parent.call_limited(maximum, **kwargs)
def call_unsafe_for_result(self, default=None, **kwargs):
return self._parent.call_unsafe_for_result(default, **kwargs)
def call_unsafe(self, **kwargs):
return self._parent.call_unsafe(**kwargs)
def call_unsafe_limited(self, maximum, **kwargs):
return self._parent.call_unsafe_limited(maximum, **kwargs)
def get_hooks(self):
return self._parent.get_hooks()
def get_children(self):
return self._parent.get_children()
class EventHook(object): class EventHook(object):
def __init__(self, log, name=None, parent=None): def __init__(self, log: Logging.Log, name: str = None,
parent: "EventHook" = None):
self.log = log self.log = log
self.name = name self.name = name
self.parent = parent self.parent = parent
@ -107,10 +52,10 @@ class EventHook(object):
self._stored_events = [] self._stored_events = []
self._context_hooks = {} self._context_hooks = {}
def _make_event(self, kwargs): def _make_event(self, kwargs: dict) -> Event:
return Event(self._get_path(), **kwargs) return Event(self._get_path(), **kwargs)
def _get_path(self): def _get_path(self) -> str:
path = [] path = []
parent = self parent = self
while not parent == None and not parent.name == None: while not parent == None and not parent.name == None:
@ -118,15 +63,17 @@ class EventHook(object):
parent = parent.parent parent = parent.parent
return DEFAULT_EVENT_DELIMITER.join(path[::-1]) return DEFAULT_EVENT_DELIMITER.join(path[::-1])
def new_context(self, context): def new_context(self, context: str) -> "EventHookContext":
return EventHookContext(self, context) return EventHookContext(self, context)
def hook(self, function, priority=DEFAULT_PRIORITY, replay=False, def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
**kwargs): replay: bool = False, **kwargs) -> EventCallback:
return self._hook(function, None, priority, replay, kwargs) return self._hook(function, None, priority, replay, kwargs)
def _context_hook(self, context, function, priority, replay, kwargs): def _context_hook(self, context: str, function: CALLBACK_TYPE,
priority: int, replay: bool, kwargs: dict) -> EventCallback:
return self._hook(function, context, priority, replay, kwargs) return self._hook(function, context, priority, replay, kwargs)
def _hook(self, function, context, priority, replay, kwargs): def _hook(self, function: CALLBACK_TYPE, context: str, priority: int,
replay: bool, kwargs: dict) -> EventCallback:
callback = EventCallback(function, priority, kwargs) callback = EventCallback(function, priority, kwargs)
if context == None: if context == None:
@ -142,7 +89,7 @@ class EventHook(object):
self._stored_events = None self._stored_events = None
return callback return callback
def unhook(self, callback): def unhook(self, callback: "EventHook"):
if callback in self._hooks: if callback in self._hooks:
self._hooks.remove(callback) self._hooks.remove(callback)
@ -155,7 +102,8 @@ class EventHook(object):
for context in empty: for context in empty:
del self._context_hooks[context] del self._context_hooks[context]
def _make_multiple_hook(self, source, context, events): def _make_multiple_hook(self, source: "EventHook", context: str,
events: typing.List[str]) -> "MultipleEventHook":
multiple_event_hook = MultipleEventHook() multiple_event_hook = MultipleEventHook()
for event in events: for event in events:
event_hook = source.get_child(event) event_hook = source.get_child(event)
@ -164,13 +112,15 @@ class EventHook(object):
multiple_event_hook._add(event_hook) multiple_event_hook._add(event_hook)
return multiple_event_hook return multiple_event_hook
def on(self, subevent, *extra_subevents, def on(self, subevent: str, *extra_subevents,
delimiter=DEFAULT_EVENT_DELIMITER): delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, None, delimiter) return self._on(subevent, extra_subevents, None, delimiter)
def _context_on(self, context, subevent, extra_subevents, def _context_on(self, context: str, subevent: str,
delimiter=DEFAULT_EVENT_DELIMITER): extra_subevents: typing.List[str],
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, context, delimiter) return self._on(subevent, extra_subevents, context, delimiter)
def _on(self, subevent, extra_subevents, context, delimiter): def _on(self, subevent: str, extra_subevents: typing.List[str],
context: str, delimiter: str) -> "EventHook":
if delimiter in subevent: if delimiter in subevent:
event_chain = subevent.split(delimiter) event_chain = subevent.split(delimiter)
event_obj = self event_obj = self
@ -193,26 +143,28 @@ class EventHook(object):
child = child.new_context(context) child = child.new_context(context)
return child return child
def call_for_result(self, default=None, **kwargs): def call_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_limited(1, **kwargs) or [default])[0] return (self.call_limited(1, **kwargs) or [default])[0]
def assure_call(self, **kwargs): def assure_call(self, **kwargs):
if not self._stored_events == None: if not self._stored_events == None:
self._stored_events.append(kwargs) self._stored_events.append(kwargs)
else: else:
self._call(kwargs, True, None) self._call(kwargs, True, None)
def call(self, **kwargs): def call(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None) return self._call(kwargs, True, None)
def call_limited(self, maximum, **kwargs): def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None) return self._call(kwargs, True, None)
def call_unsafe_for_result(self, default=None, **kwargs): def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_unsafe_limited(1, **kwargs) or [default])[0] return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
def call_unsafe(self, **kwargs): def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, False, None) return self._call(kwargs, False, None)
def call_unsafe_limited(self, maximum, **kwargs): def call_unsafe_limited(self, maximum: int, **kwargs
) -> typing.List[typing.Any]:
return self._call(kwargs, False, maximum) return self._call(kwargs, False, maximum)
def _call(self, kwargs, safe, maximum): def _call(self, kwargs: dict, safe: bool, maximum: int
) -> typing.List[typing.Any]:
event_path = self._get_path() event_path = self._get_path()
self.log.trace("calling event: \"%s\" (params: %s)", self.log.trace("calling event: \"%s\" (params: %s)",
[event_path, kwargs]) [event_path, kwargs])
@ -240,13 +192,13 @@ class EventHook(object):
return returns return returns
def get_child(self, child_name): def get_child(self, child_name: str) -> "EventHook":
child_name_lower = child_name.lower() child_name_lower = child_name.lower()
if not child_name_lower in self._children: if not child_name_lower in self._children:
self._children[child_name_lower] = EventHook(self.log, self._children[child_name_lower] = EventHook(self.log,
child_name_lower, self) child_name_lower, self)
return self._children[child_name_lower] return self._children[child_name_lower]
def remove_child(self, child_name): def remove_child(self, child_name: str):
child_name_lower = child_name.lower() child_name_lower = child_name.lower()
if child_name_lower in self._children: if child_name_lower in self._children:
del self._children[child_name_lower] del self._children[child_name_lower]
@ -256,11 +208,11 @@ class EventHook(object):
self.parent.remove_child(self.name) self.parent.remove_child(self.name)
self.parent.check_purge() self.parent.check_purge()
def remove_context(self, context): def remove_context(self, context: str):
del self._context_hooks[context] del self._context_hooks[context]
def has_context(self, context): def has_context(self, context: str) -> bool:
return context in self._context_hooks return context in self._context_hooks
def purge_context(self, context): def purge_context(self, context: str):
if self.has_context(context): if self.has_context(context):
self.remove_context(context) self.remove_context(context)
@ -268,10 +220,69 @@ class EventHook(object):
child = self.get_child(child_name) child = self.get_child(child_name)
child.purge_context(context) child.purge_context(context)
def get_hooks(self): def get_hooks(self) -> typing.List[EventCallback]:
return sorted(self._hooks + sum(self._context_hooks.values(), []), return sorted(self._hooks + sum(self._context_hooks.values(), []),
key=lambda e: e.priority) key=lambda e: e.priority)
def get_children(self): def get_children(self) -> typing.List["EventHook"]:
return list(self._children.keys()) return list(self._children.keys())
def is_empty(self): def is_empty(self) -> bool:
return len(self.get_hooks() + self.get_children()) == 0 return len(self.get_hooks() + self.get_children()) == 0
class MultipleEventHook(object):
def __init__(self):
self._event_hooks = set([])
def _add(self, event_hook: EventHook):
self._event_hooks.add(event_hook)
def hook(self, function: CALLBACK_TYPE, **kwargs):
for event_hook in self._event_hooks:
event_hook.hook(function, **kwargs)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call_limited(maximum, **kwargs))
return returns
def call(self, **kwargs) -> typing.List[typing.Any]:
returns = []
for event_hook in self._event_hooks:
returns.append(event_hook.call(**kwargs))
return returns
class EventHookContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
replay: bool = False, **kwargs) -> EventCallback:
return self._parent._context_hook(self.context, function, priority,
replay, kwargs)
def unhook(self, callback: EventCallback):
self._parent.unhook(callback)
def on(self, subevent: str, *extra_subevents,
delimiter: str = DEFAULT_EVENT_DELIMITER) -> EventHook:
return self._parent._context_on(self.context, subevent,
extra_subevents, delimiter)
def call_for_result(self, default=None, **kwargs) -> typing.Any:
return self._parent.call_for_result(default, **kwargs)
def assure_call(self, **kwargs):
self._parent.assure_call(**kwargs)
def call(self, **kwargs) -> typing.List[typing.Any]:
return self._parent.call(**kwargs)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._parent.call_limited(maximum, **kwargs)
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return self._parent.call_unsafe_for_result(default, **kwargs)
def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
return self._parent.call_unsafe(**kwargs)
def call_unsafe_limited(self, maximum: int, **kwargs
) -> typing.List[typing.Any]:
return self._parent.call_unsafe_limited(maximum, **kwargs)
def get_hooks(self) -> typing.List[EventCallback]:
return self._parent.get_hooks()
def get_children(self) -> typing.List[EventHook]:
return self._parent.get_children()

View file

@ -1,28 +1,18 @@
import typing
class ExportsContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def add(self, setting, value):
self._parent._context_add(self.context, setting, value)
def get_all(self, setting):
return self._parent.get_all(setting)
class Exports(object): class Exports(object):
def __init__(self): def __init__(self):
self._exports = {} self._exports = {}
self._context_exports = {} self._context_exports = {}
def new_context(self, context): def new_context(self, context: str) -> "ExportsContext":
return ExportsContext(self, context) return ExportsContext(self, context)
def add(self, setting, value): def add(self, setting: str, value: typing.Any):
self._add(None, setting, value) self._add(None, setting, value)
def _context_add(self, context, setting, value): def _context_add(self, context: str, setting: str, value: typing.Any):
self._add(context, setting, value) self._add(context, setting, value)
def _add(self, context, setting, value): def _add(self, context: str, setting: str, value: typing.Any):
if context == None: if context == None:
if not setting in self_exports: if not setting in self_exports:
self._exports[setting] = [] self._exports[setting] = []
@ -34,11 +24,21 @@ class Exports(object):
self._context_exports[context][setting] = [] self._context_exports[context][setting] = []
self._context_exports[context][setting].append(value) self._context_exports[context][setting].append(value)
def get_all(self, setting): def get_all(self, setting: str) -> typing.List[typing.Any]:
return self._exports.get(setting, []) + sum([ return self._exports.get(setting, []) + sum([
exports.get(setting, []) for exports in exports.get(setting, []) for exports in
self._context_exports.values()], []) self._context_exports.values()], [])
def purge_context(self, context): def purge_context(self, context: str):
if context in self._context_exports: if context in self._context_exports:
del self._context_exports[context] del self._context_exports[context]
class ExportsContext(object):
def __init__(self, parent: Exports, context: str):
self._parent = parent
self.context = context
def add(self, setting: str, value: typing.Any):
self._parent._context_add(self.context, setting, value)
def get_all(self, setting: str) -> typing.List[typing.Any]:
return self._parent.get_all(setting)

View file

@ -1,4 +1,4 @@
import os, select, socket, sys, threading, time, traceback, uuid import os, select, socket, sys, threading, time, traceback, typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import Socket, utils from src import Socket, utils
@ -28,14 +28,15 @@ class Bot(object):
self._trigger_functions = [] self._trigger_functions = []
def trigger(self, func=None): def trigger(self, func: typing.Callable[[], typing.Any]=None):
self.lock.acquire() self.lock.acquire()
if func: if func:
self._trigger_functions.append(func) self._trigger_functions.append(func)
self._trigger_client.send(b"TRIGGER") self._trigger_client.send(b"TRIGGER")
self.lock.release() self.lock.release()
def add_server(self, server_id, connect=True): def add_server(self, server_id: int, connect: bool = True
) -> typing.Optional[IRCServer.Server]:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id) username, realname) = self.database.servers.get(server_id)
@ -49,20 +50,20 @@ class Bot(object):
self.connect(new_server) self.connect(new_server)
return new_server return new_server
def add_socket(self, sock): def add_socket(self, sock: socket.socket):
self.other_sockets[sock.fileno()] = sock self.other_sockets[sock.fileno()] = sock
self.poll.register(sock.fileno(), select.EPOLLIN) self.poll.register(sock.fileno(), select.EPOLLIN)
def remove_socket(self, sock): def remove_socket(self, sock: socket.socket):
del self.other_sockets[sock.fileno()] del self.other_sockets[sock.fileno()]
self.poll.unregister(sock.fileno()) self.poll.unregister(sock.fileno())
def get_server(self, id): def get_server(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values(): for server in self.servers.values():
if server.id == id: if server.id == id:
return server return server
def connect(self, server): def connect(self, server: IRCServer.Server) -> bool:
try: try:
server.connect() server.connect()
except: except:
@ -73,7 +74,7 @@ class Bot(object):
self.poll.register(server.fileno(), select.EPOLLOUT) self.poll.register(server.fileno(), select.EPOLLOUT)
return True return True
def next_send(self): def next_send(self) -> typing.Optional[float]:
next = None next = None
for server in self.servers.values(): for server in self.servers.values():
timeout = server.send_throttle_timeout() timeout = server.send_throttle_timeout()
@ -81,7 +82,7 @@ class Bot(object):
next = timeout next = timeout
return next return next
def next_ping(self): def next_ping(self) -> typing.Optional[float]:
timeouts = [] timeouts = []
for server in self.servers.values(): for server in self.servers.values():
timeout = server.until_next_ping() timeout = server.until_next_ping()
@ -90,7 +91,8 @@ class Bot(object):
if not timeouts: if not timeouts:
return None return None
return min(timeouts) return min(timeouts)
def next_read_timeout(self):
def next_read_timeout(self) -> typing.Optional[float]:
timeouts = [] timeouts = []
for server in self.servers.values(): for server in self.servers.values():
timeouts.append(server.until_read_timeout()) timeouts.append(server.until_read_timeout())
@ -98,7 +100,7 @@ class Bot(object):
return None return None
return min(timeouts) return min(timeouts)
def get_poll_timeout(self): def get_poll_timeout(self) -> float:
timeouts = [] timeouts = []
timeouts.append(self._timers.next()) timeouts.append(self._timers.next())
timeouts.append(self.next_send()) timeouts.append(self.next_send())
@ -107,15 +109,15 @@ class Bot(object):
timeouts.append(self.cache.next_expiration()) timeouts.append(self.cache.next_expiration())
return min([timeout for timeout in timeouts if not timeout == None]) return min([timeout for timeout in timeouts if not timeout == None])
def register_read(self, server): def register_read(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLIN) self.poll.modify(server.fileno(), select.EPOLLIN)
def register_write(self, server): def register_write(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLOUT) self.poll.modify(server.fileno(), select.EPOLLOUT)
def register_both(self, server): def register_both(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), self.poll.modify(server.fileno(),
select.EPOLLIN|select.EPOLLOUT) select.EPOLLIN|select.EPOLLOUT)
def disconnect(self, server): def disconnect(self, server: IRCServer.Server):
try: try:
self.poll.unregister(server.fileno()) self.poll.unregister(server.fileno())
except FileNotFoundError: except FileNotFoundError:
@ -123,23 +125,25 @@ class Bot(object):
del self.servers[server.fileno()] del self.servers[server.fileno()]
@utils.hook("timer.reconnect") @utils.hook("timer.reconnect")
def reconnect(self, event): def reconnect(self, event: EventManager.Event):
server = self.add_server(event["server_id"], False) server = self.add_server(event["server_id"], False)
if self.connect(server): if self.connect(server):
self.servers[server.fileno()] = server self.servers[server.fileno()] = server
else: else:
event["timer"].redo() event["timer"].redo()
def set_setting(self, setting, value): def set_setting(self, setting: str, value: typing.Any):
self.database.bot_settings.set(setting, value) self.database.bot_settings.set(setting, value)
def get_setting(self, setting, default=None): def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.database.bot_settings.get(setting, default) return self.database.bot_settings.get(setting, default)
def find_settings(self, pattern, default=[]): def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.database.bot_settings.find(pattern, default) return self.database.bot_settings.find(pattern, default)
def find_settings_prefix(self, prefix, default=[]): def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.database.bot_settings.find_prefix( return self.database.bot_settings.find_prefix(
prefix, default) prefix, default)
def del_setting(self, setting): def del_setting(self, setting: str):
self.database.bot_settings.delete(setting) self.database.bot_settings.delete(setting)
def run(self): def run(self):

View file

@ -1,8 +1,9 @@
import re import re, typing
from src import utils from src import IRCBot, utils
class BufferLine(object): class BufferLine(object):
def __init__(self, sender, message, action, tags, from_self, method): def __init__(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool, method: str):
self.sender = sender self.sender = sender
self.message = message self.message = message
self.action = action self.action = action
@ -11,35 +12,39 @@ class BufferLine(object):
self.method = method self.method = method
class Buffer(object): class Buffer(object):
def __init__(self, bot, server): def __init__(self, bot: "IRCBot.Bot", server: "IRCServer.Server"):
self.bot = bot self.bot = bot
self.server = server self.server = server
self.lines = [] self.lines = []
self.max_lines = 64 self.max_lines = 64
self._skip_next = False self._skip_next = False
def _add_message(self, sender, message, action, tags, from_self, method): def _add_message(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool, method: str):
if not self._skip_next: if not self._skip_next:
line = BufferLine(sender, message, action, tags, from_self, method) line = BufferLine(sender, message, action, tags, from_self, method)
self.lines.insert(0, line) self.lines.insert(0, line)
if len(self.lines) > self.max_lines: if len(self.lines) > self.max_lines:
self.lines.pop() self.lines.pop()
self._skip_next = False self._skip_next = False
def add_message(self, sender, message, action, tags, from_self=False): def add_message(self, sender: str, message: str, action: bool, tags: dict,
from_self: bool=False):
self._add_message(sender, message, action, tags, from_self, "PRIVMSG") self._add_message(sender, message, action, tags, from_self, "PRIVMSG")
def add_notice(self, sender, message, tags, from_self=False): def add_notice(self, sender: str, message: str, tags: dict,
from_self: bool=False):
self._add_message(sender, message, False, tags, from_self, "NOTICE") self._add_message(sender, message, False, tags, from_self, "NOTICE")
def get(self, index=0, **kwargs): def get(self, index: int=0, **kwargs) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True) from_self = kwargs.get("from_self", True)
for line in self.lines: for line in self.lines:
if line.from_self and not from_self: if line.from_self and not from_self:
continue continue
return line return line
def find(self, pattern, **kwargs): def find(self, pattern: typing.Union[str, typing.Pattern[str]], **kwargs
) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True) from_self = kwargs.get("from_self", True)
for_user = kwargs.get("for_user", "") for_user = kwargs.get("for_user", "")
for_user = utils.irc.lower(self.server, for_user for_user = utils.irc.lower(self.server.case_mapping, for_user
) if for_user else None ) if for_user else None
not_pattern = kwargs.get("not_pattern", None) not_pattern = kwargs.get("not_pattern", None)
for line in self.lines: for line in self.lines:
@ -48,8 +53,8 @@ class Buffer(object):
elif re.search(pattern, line.message): elif re.search(pattern, line.message):
if not_pattern and re.search(not_pattern, line.message): if not_pattern and re.search(not_pattern, line.message):
continue continue
if for_user and not utils.irc.lower(self.server, line.sender if for_user and not utils.irc.lower(self.server.case_mapping,
) == for_user: line.sender) == for_user:
continue continue
return line return line
def skip_next(self): def skip_next(self):

View file

@ -1,9 +1,10 @@
import uuid import typing, uuid
from src import IRCBuffer, IRCObject, utils from src import IRCBot, IRCBuffer, IRCObject, IRCServer, IRCUser, utils
class Channel(IRCObject.Object): class Channel(IRCObject.Object):
def __init__(self, name, id, server, bot): def __init__(self, name: str, id, server: "IRCServer.Server",
self.name = utils.irc.lower(server, name) bot: "IRCBot.Bot"):
self.name = utils.irc.lower(server.case_mapping, name)
self.id = id self.id = id
self.server = server self.server = server
self.bot = bot self.bot = bot
@ -18,23 +19,24 @@ class Channel(IRCObject.Object):
self.created_timestamp = None self.created_timestamp = None
self.buffer = IRCBuffer.Buffer(bot, server) self.buffer = IRCBuffer.Buffer(bot, server)
def __repr__(self): def __repr__(self) -> str:
return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name) return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name)
def __str__(self): def __str__(self) -> str:
return self.name return self.name
def set_topic(self, topic): def set_topic(self, topic: str):
self.topic = topic self.topic = topic
def set_topic_setter(self, nickname, username=None, hostname=None): def set_topic_setter(self, nickname: str, username: str=None,
hostname: str=None):
self.topic_setter_nickname = nickname self.topic_setter_nickname = nickname
self.topic_setter_username = username self.topic_setter_username = username
self.topic_setter_hostname = hostname self.topic_setter_hostname = hostname
def set_topic_time(self, unix_timestamp): def set_topic_time(self, unix_timestamp: int):
self.topic_time = unix_timestamp self.topic_time = unix_timestamp
def add_user(self, user): def add_user(self, user: IRCUser.User):
self.users.add(user) self.users.add(user)
def remove_user(self, user): def remove_user(self, user: IRCUser.User):
self.users.remove(user) self.users.remove(user)
for mode in list(self.modes.keys()): for mode in list(self.modes.keys()):
if mode in self.server.prefix_modes and user in self.modes[mode]: if mode in self.server.prefix_modes and user in self.modes[mode]:
@ -43,10 +45,10 @@ class Channel(IRCObject.Object):
del self.modes[mode] del self.modes[mode]
if user in self.user_modes: if user in self.user_modes:
del self.user_modes[user] del self.user_modes[user]
def has_user(self, user): def has_user(self, user: IRCUser.User) -> bool:
return user in self.users return user in self.users
def add_mode(self, mode, arg=None): def add_mode(self, mode: str, arg: str=None):
if not mode in self.modes: if not mode in self.modes:
self.modes[mode] = set([]) self.modes[mode] = set([])
if arg: if arg:
@ -59,7 +61,7 @@ class Channel(IRCObject.Object):
self.user_modes[user].add(mode) self.user_modes[user].add(mode)
else: else:
self.modes[mode].add(arg.lower()) self.modes[mode].add(arg.lower())
def remove_mode(self, mode, arg=None): def remove_mode(self, mode: str, arg: str=None):
if not arg: if not arg:
del self.modes[mode] del self.modes[mode]
else: else:
@ -76,63 +78,70 @@ class Channel(IRCObject.Object):
self.modes[mode].discard(arg.lower()) self.modes[mode].discard(arg.lower())
if not len(self.modes[mode]): if not len(self.modes[mode]):
del self.modes[mode] del self.modes[mode]
def change_mode(self, remove, mode, arg=None): def change_mode(self, remove: bool, mode: str, arg: str=None):
if remove: if remove:
self.remove_mode(mode, arg) self.remove_mode(mode, arg)
else: else:
self.add_mode(mode, arg) self.add_mode(mode, arg)
def set_setting(self, setting, value): def set_setting(self, setting: str, value: typing.Any):
self.bot.database.channel_settings.set(self.id, setting, value) self.bot.database.channel_settings.set(self.id, setting, value)
def get_setting(self, setting, default=None): def get_setting(self, setting: str, default: typing.Any=None
) -> typing.Any:
return self.bot.database.channel_settings.get(self.id, setting, return self.bot.database.channel_settings.get(self.id, setting,
default) default)
def find_settings(self, pattern, default=[]): def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find(self.id, pattern, return self.bot.database.channel_settings.find(self.id, pattern,
default) default)
def find_settings_prefix(self, prefix, default=[]): def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find_prefix(self.id, return self.bot.database.channel_settings.find_prefix(self.id,
prefix, default) prefix, default)
def del_setting(self, setting): def del_setting(self, setting: str):
self.bot.database.channel_settings.delete(self.id, setting) self.bot.database.channel_settings.delete(self.id, setting)
def set_user_setting(self, user_id, setting, value): def set_user_setting(self, user_id: int, setting: str, value: typing.Any):
self.bot.database.user_channel_settings.set(user_id, self.id, self.bot.database.user_channel_settings.set(user_id, self.id,
setting, value) setting, value)
def get_user_setting(self, user_id, setting, default=None): def get_user_setting(self, user_id: int, setting: str,
default: typing.Any=None) -> typing.Any:
return self.bot.database.user_channel_settings.get(user_id, return self.bot.database.user_channel_settings.get(user_id,
self.id, setting, default) self.id, setting, default)
def find_user_settings(self, user_i, pattern, default=[]): def find_user_settings(self, user_id: int, pattern: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find(user_id, return self.bot.database.user_channel_settings.find(user_id,
self.id, pattern, default) self.id, pattern, default)
def find_user_settings_prefix(self, user_id, prefix, default=[]): def find_user_settings_prefix(self, user_id: int, prefix: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_prefix( return self.bot.database.user_channel_settings.find_prefix(
user_id, self.id, prefix, default) user_id, self.id, prefix, default)
def del_user_setting(self, user_id, setting): def del_user_setting(self, user_id: int, setting: str):
self.bot.database.user_channel_settings.delete(user_id, self.id, self.bot.database.user_channel_settings.delete(user_id, self.id,
setting) setting)
def find_all_by_setting(self, setting, default=[]): def find_all_by_setting(self, setting: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_all_by_setting( return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default) self.id, setting, default)
def send_message(self, text, prefix=None, tags={}): def send_message(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.name, text, prefix=prefix, tags=tags) self.server.send_message(self.name, text, prefix=prefix, tags=tags)
def send_notice(self, text, prefix=None, tags={}): def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.name, text, prefix=prefix, tags=tags) self.server.send_notice(self.name, text, prefix=prefix, tags=tags)
def send_mode(self, mode=None, target=None): def send_mode(self, mode: str=None, target: str=None):
self.server.send_mode(self.name, mode, target) self.server.send_mode(self.name, mode, target)
def send_kick(self, target, reason=None): def send_kick(self, target: str, reason: str=None):
self.server.send_kick(self.name, target, reason) self.server.send_kick(self.name, target, reason)
def send_ban(self, hostmask): def send_ban(self, hostmask: str):
self.server.send_mode(self.name, "+b", hostmask) self.server.send_mode(self.name, "+b", hostmask)
def send_unban(self, hostmask): def send_unban(self, hostmask: str):
self.server.send_mode(self.name, "-b", hostmask) self.server.send_mode(self.name, "-b", hostmask)
def send_topic(self, topic): def send_topic(self, topic: str):
self.server.send_topic(self.name, topic) self.server.send_topic(self.name, topic)
def send_part(self, reason=None): def send_part(self, reason: str=None):
self.server.send_part(self.name, reason) self.server.send_part(self.name, reason)
def mode_or_above(self, user, mode): def mode_or_above(self, user: IRCUser.User, mode: str) -> bool:
mode_orders = list(self.server.prefix_modes) mode_orders = list(self.server.prefix_modes)
mode_index = mode_orders.index(mode) mode_index = mode_orders.index(mode)
for mode in mode_orders[:mode_index+1]: for mode in mode_orders[:mode_index+1]:
@ -140,8 +149,8 @@ class Channel(IRCObject.Object):
return True return True
return False return False
def has_mode(self, user, mode): def has_mode(self, user: IRCUser.User, mode: str) -> bool:
return user in self.modes.get(mode, []) return user in self.modes.get(mode, [])
def get_user_status(self, user): def get_user_status(self, user: IRCUser.User) -> typing.Set:
return self.user_modes.get(user, []) return self.user_modes.get(user, [])

View file

@ -1,5 +1,5 @@
import collections, socket, ssl, sys, time import collections, socket, ssl, sys, time, typing
from src import IRCChannel, IRCObject, IRCUser, utils from src import EventManager, IRCBot, IRCChannel, IRCObject, IRCUser, utils
THROTTLE_LINES = 4 THROTTLE_LINES = 4
THROTTLE_SECONDS = 1 THROTTLE_SECONDS = 1
@ -7,8 +7,12 @@ READ_TIMEOUT_SECONDS = 120
PING_INTERVAL_SECONDS = 30 PING_INTERVAL_SECONDS = 30
class Server(IRCObject.Object): class Server(IRCObject.Object):
def __init__(self, bot, events, id, alias, hostname, port, password, def __init__(self,
ipv4, tls, bindhost, nickname, username, realname): bot: "IRCBot.Bot",
events: EventManager.EventHook,
id: int, alias: str, hostname: str, port: int, password: str,
ipv4: bool, tls: bool, bindhost: str,
nickname: str, username: str, realname: str):
self.connected = False self.connected = False
self.bot = bot self.bot = bot
self.events = events self.events = events
@ -121,77 +125,80 @@ class Server(IRCObject.Object):
except: except:
pass pass
def set_setting(self, setting, value): def set_setting(self, setting: str, value: typing.Any):
self.bot.database.server_settings.set(self.id, setting, self.bot.database.server_settings.set(self.id, setting,
value) value)
def get_setting(self, setting, default=None): def get_setting(self, setting: str, default: typing.Any=None):
return self.bot.database.server_settings.get(self.id, return self.bot.database.server_settings.get(self.id,
setting, default) setting, default)
def find_settings(self, pattern, default=[]): def find_settings(self, pattern: str, default: typing.Any=[]):
return self.bot.database.server_settings.find(self.id, return self.bot.database.server_settings.find(self.id,
pattern, default) pattern, default)
def find_settings_prefix(self, prefix, default=[]): def find_settings_prefix(self, prefix: str, default: typing.Any=[]):
return self.bot.database.server_settings.find_prefix( return self.bot.database.server_settings.find_prefix(
self.id, prefix, default) self.id, prefix, default)
def del_setting(self, setting): def del_setting(self, setting: str):
self.bot.database.server_settings.delete(self.id, setting) self.bot.database.server_settings.delete(self.id, setting)
def get_user_setting(self, nickname, setting, default=None): def get_user_setting(self, nickname: str, setting: str,
default: typing.Any=None):
user_id = self.get_user_id(nickname) user_id = self.get_user_id(nickname)
return self.bot.database.user_settings.get(user_id, setting, default) return self.bot.database.user_settings.get(user_id, setting, default)
def set_user_setting(self, nickname, setting, value): def set_user_setting(self, nickname: str, setting: str, value: typing.Any):
user_id = self.get_user_id(nickname) user_id = self.get_user_id(nickname)
self.bot.database.user_settings.set(user_id, setting, value) self.bot.database.user_settings.set(user_id, setting, value)
def get_all_user_settings(self, setting, default=[]): def get_all_user_settings(self, setting: str, default: typing.Any=[]):
return self.bot.database.user_settings.find_all_by_setting( return self.bot.database.user_settings.find_all_by_setting(
self.id, setting, default) self.id, setting, default)
def find_all_user_channel_settings(self, setting, default=[]): def find_all_user_channel_settings(self, setting: str,
default: typing.Any=[]):
return self.bot.database.user_channel_settings.find_all_by_setting( return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default) self.id, setting, default)
def set_own_nickname(self, nickname): def set_own_nickname(self, nickname: str):
self.nickname = nickname self.nickname = nickname
self.nickname_lower = utils.irc.lower(self, nickname) self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
def is_own_nickname(self, nickname): def is_own_nickname(self, nickname: str):
return utils.irc.equals(self, nickname, self.nickname) return utils.irc.equals(self, nickname, self.nickname)
def add_own_mode(self, mode, arg=None): def add_own_mode(self, mode: str, arg: str=None):
self.own_modes[mode] = arg self.own_modes[mode] = arg
def remove_own_mode(self, mode): def remove_own_mode(self, mode: str):
del self.own_modes[mode] del self.own_modes[mode]
def change_own_mode(self, remove, mode, arg=None): def change_own_mode(self, remove: bool, mode: str, arg: str=None):
if remove: if remove:
self.remove_own_mode(mode) self.remove_own_mode(mode)
else: else:
self.add_own_mode(mode, arg) self.add_own_mode(mode, arg)
def has_user(self, nickname): def has_user(self, nickname: str):
return utils.irc.lower(self, nickname) in self.users return utils.irc.lower(self.case_mapping, nickname) in self.users
def get_user(self, nickname, create=True): def get_user(self, nickname: str, create: bool=True):
if not self.has_user(nickname) and create: if not self.has_user(nickname) and create:
user_id = self.get_user_id(nickname) user_id = self.get_user_id(nickname)
new_user = IRCUser.User(nickname, user_id, self, self.bot) new_user = IRCUser.User(nickname, user_id, self, self.bot)
self.events.on("new.user").call(user=new_user, server=self) self.events.on("new.user").call(user=new_user, server=self)
self.users[new_user.nickname_lower] = new_user self.users[new_user.nickname_lower] = new_user
self.new_users.add(new_user) self.new_users.add(new_user)
return self.users.get(utils.irc.lower(self, nickname), None) return self.users.get(utils.irc.lower(self.case_mapping, nickname),
def get_user_id(self, nickname): None)
def get_user_id(self, nickname: str):
self.bot.database.users.add(self.id, nickname) self.bot.database.users.add(self.id, nickname)
return self.bot.database.users.get_id(self.id, nickname) return self.bot.database.users.get_id(self.id, nickname)
def remove_user(self, user): def remove_user(self, user: IRCUser.User):
del self.users[user.nickname_lower] del self.users[user.nickname_lower]
for channel in user.channels: for channel in user.channels:
channel.remove_user(user) channel.remove_user(user)
def change_user_nickname(self, old_nickname, new_nickname): def change_user_nickname(self, old_nickname: str, new_nickname: str):
user = self.users.pop(utils.irc.lower(self, old_nickname)) user = self.users.pop(utils.irc.lower(self.case_mapping, old_nickname))
user._id = self.get_user_id(new_nickname) user._id = self.get_user_id(new_nickname)
self.users[utils.irc.lower(self, new_nickname)] = user self.users[utils.irc.lower(self.case_mapping, new_nickname)] = user
def has_channel(self, channel_name): def has_channel(self, channel_name: str):
return channel_name[0] in self.channel_types and utils.irc.lower( return channel_name[0] in self.channel_types and utils.irc.lower(
self, channel_name) in self.channels self.case_mapping, channel_name) in self.channels
def get_channel(self, channel_name): def get_channel(self, channel_name: str):
if not self.has_channel(channel_name): if not self.has_channel(channel_name):
channel_id = self.get_channel_id(channel_name) channel_id = self.get_channel_id(channel_name)
new_channel = IRCChannel.Channel(channel_name, channel_id, new_channel = IRCChannel.Channel(channel_name, channel_id,
@ -199,15 +206,15 @@ class Server(IRCObject.Object):
self.events.on("new.channel").call(channel=new_channel, self.events.on("new.channel").call(channel=new_channel,
server=self) server=self)
self.channels[new_channel.name] = new_channel self.channels[new_channel.name] = new_channel
return self.channels[utils.irc.lower(self, channel_name)] return self.channels[utils.irc.lower(self.case_mapping, channel_name)]
def get_channel_id(self, channel_name): def get_channel_id(self, channel_name: str):
self.bot.database.channels.add(self.id, channel_name) self.bot.database.channels.add(self.id, channel_name)
return self.bot.database.channels.get_id(self.id, channel_name) return self.bot.database.channels.get_id(self.id, channel_name)
def remove_channel(self, channel): def remove_channel(self, channel: IRCChannel.Channel):
for user in channel.users: for user in channel.users:
user.part_channel(channel) user.part_channel(channel)
del self.channels[channel.name] del self.channels[channel.name]
def parse_data(self, line): def parse_data(self, line: str):
if not line: if not line:
return return
self.events.on("raw").call_unsafe(server=self, line=line) self.events.on("raw").call_unsafe(server=self, line=line)
@ -271,7 +278,7 @@ class Server(IRCObject.Object):
def read_timed_out(self): def read_timed_out(self):
return self.until_read_timeout == 0 return self.until_read_timeout == 0
def send(self, data): def send(self, data: str):
returned = self.events.on("preprocess.send").call_unsafe_for_result( returned = self.events.on("preprocess.send").call_unsafe_for_result(
server=self, line=data) server=self, line=data)
line = returned or data line = returned or data
@ -314,16 +321,16 @@ class Server(IRCObject.Object):
time_left = time_left-now time_left = time_left-now
return time_left return time_left
def send_user(self, username, realname): def send_user(self, username: str, realname: str):
self.send("USER %s 0 * :%s" % (username, realname)) self.send("USER %s 0 * :%s" % (username, realname))
def send_nick(self, nickname): def send_nick(self, nickname: str):
self.send("NICK %s" % nickname) self.send("NICK %s" % nickname)
def send_capibility_ls(self): def send_capibility_ls(self):
self.send("CAP LS 302") self.send("CAP LS 302")
def queue_capability(self, capability): def queue_capability(self, capability: str):
self._capability_queue.add(capability) self._capability_queue.add(capability)
def queue_capabilities(self, capabilities): def queue_capabilities(self, capabilities: typing.List[str]):
self._capability_queue.update(capabilities) self._capability_queue.update(capabilities)
def send_capability_queue(self): def send_capability_queue(self):
if self.has_capability_queue(): if self.has_capability_queue():
@ -332,46 +339,46 @@ class Server(IRCObject.Object):
self.send_capability_request(capabilities) self.send_capability_request(capabilities)
def has_capability_queue(self): def has_capability_queue(self):
return bool(len(self._capability_queue)) return bool(len(self._capability_queue))
def send_capability_request(self, capability): def send_capability_request(self, capability: str):
self.send("CAP REQ :%s" % capability) self.send("CAP REQ :%s" % capability)
def send_capability_end(self): def send_capability_end(self):
self.send("CAP END") self.send("CAP END")
def send_authenticate(self, text): def send_authenticate(self, text: str):
self.send("AUTHENTICATE %s" % text) self.send("AUTHENTICATE %s" % text)
def send_starttls(self): def send_starttls(self):
self.send("STARTTLS") self.send("STARTTLS")
def waiting_for_capabilities(self): def waiting_for_capabilities(self):
return bool(len(self._capabilities_waiting)) return bool(len(self._capabilities_waiting))
def wait_for_capability(self, capability): def wait_for_capability(self, capability: str):
self._capabilities_waiting.add(capability) self._capabilities_waiting.add(capability)
def capability_done(self, capability): def capability_done(self, capability: str):
self._capabilities_waiting.remove(capability) self._capabilities_waiting.remove(capability)
if not self._capabilities_waiting: if not self._capabilities_waiting:
self.send_capability_end() self.send_capability_end()
def send_pass(self, password): def send_pass(self, password: str):
self.send("PASS %s" % password) self.send("PASS %s" % password)
def send_ping(self, nonce="hello"): def send_ping(self, nonce: str="hello"):
self.send("PING :%s" % nonce) self.send("PING :%s" % nonce)
def send_pong(self, nonce="hello"): def send_pong(self, nonce: str="hello"):
self.send("PONG :%s" % nonce) self.send("PONG :%s" % nonce)
def try_rejoin(self, event): def try_rejoin(self, event: EventManager.Event):
if event["server_id"] == self.id and event["channel_name" if event["server_id"] == self.id and event["channel_name"
] in self.attempted_join: ] in self.attempted_join:
self.send_join(event["channel_name"], event["key"]) self.send_join(event["channel_name"], event["key"])
def send_join(self, channel_name, key=None): def send_join(self, channel_name: str, key: str=None):
self.send("JOIN %s%s" % (channel_name, self.send("JOIN %s%s" % (channel_name,
"" if key == None else " %s" % key)) "" if key == None else " %s" % key))
def send_part(self, channel_name, reason=None): def send_part(self, channel_name: str, reason: str=None):
self.send("PART %s%s" % (channel_name, self.send("PART %s%s" % (channel_name,
"" if reason == None else " %s" % reason)) "" if reason == None else " %s" % reason))
def send_quit(self, reason="Leaving"): def send_quit(self, reason: str="Leaving"):
self.send("QUIT :%s" % reason) self.send("QUIT :%s" % reason)
def _tag_str(self, tags): def _tag_str(self, tags: dict):
tag_str = "" tag_str = ""
for tag, value in tags.items(): for tag, value in tags.items():
if tag_str: if tag_str:
@ -383,7 +390,8 @@ class Server(IRCObject.Object):
tag_str = "@%s " % tag_str tag_str = "@%s " % tag_str
return tag_str return tag_str
def send_message(self, target, message, prefix=None, tags={}): def send_message(self, target: str, message: str, prefix: str=None,
tags: dict={}):
full_message = message if not prefix else prefix+message full_message = message if not prefix else prefix+message
self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target, self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target,
full_message)) full_message))
@ -408,7 +416,8 @@ class Server(IRCObject.Object):
message=full_message, message_split=full_message_split, message=full_message, message_split=full_message_split,
user=user, action=action, server=self) user=user, action=action, server=self)
def send_notice(self, target, message, prefix=None, tags={}): def send_notice(self, target: str, message: str, prefix: str=None,
tags: dict={}):
full_message = message if not prefix else prefix+message full_message = message if not prefix else prefix+message
self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target, self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target,
full_message)) full_message))
@ -419,31 +428,31 @@ class Server(IRCObject.Object):
self.get_user(target).buffer.add_notice(None, message, tags, self.get_user(target).buffer.add_notice(None, message, tags,
True) True)
def send_mode(self, target, mode=None, args=None): def send_mode(self, target: str, mode: str=None, args: str=None):
self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode, self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode,
"" if args == None else " %s" % args)) "" if args == None else " %s" % args))
def send_topic(self, channel_name, topic): def send_topic(self, channel_name: str, topic: str):
self.send("TOPIC %s :%s" % (channel_name, topic)) self.send("TOPIC %s :%s" % (channel_name, topic))
def send_kick(self, channel_name, target, reason=None): def send_kick(self, channel_name: str, target: str, reason: str=None):
self.send("KICK %s %s%s" % (channel_name, target, self.send("KICK %s %s%s" % (channel_name, target,
"" if reason == None else " :%s" % reason)) "" if reason == None else " :%s" % reason))
def send_names(self, channel_name): def send_names(self, channel_name: str):
self.send("NAMES %s" % channel_name) self.send("NAMES %s" % channel_name)
def send_list(self, search_for=None): def send_list(self, search_for: str=None):
self.send( self.send(
"LIST%s" % "" if search_for == None else " %s" % search_for) "LIST%s" % "" if search_for == None else " %s" % search_for)
def send_invite(self, target, channel_name): def send_invite(self, target: str, channel_name: str):
self.send("INVITE %s %s" % (target, channel_name)) self.send("INVITE %s %s" % (target, channel_name))
def send_whois(self, target): def send_whois(self, target: str):
self.send("WHOIS %s" % target) self.send("WHOIS %s" % target)
def send_whowas(self, target, amount=None, server=None): def send_whowas(self, target: str, amount: int=None, server: str=None):
self.send("WHOWAS %s%s%s" % (target, self.send("WHOWAS %s%s%s" % (target,
"" if amount == None else " %s" % amount, "" if amount == None else " %s" % amount,
"" if server == None else " :%s" % server)) "" if server == None else " :%s" % server))
def send_who(self, filter=None): def send_who(self, filter: str=None):
self.send("WHO%s" % ("" if filter == None else " %s" % filter)) self.send("WHO%s" % ("" if filter == None else " %s" % filter))
def send_whox(self, mask, filter, fields, label=None): def send_whox(self, mask: str, filter: str, fields: str, label: str=None):
self.send("WHO %s %s%%%s%s" % (mask, filter, fields, self.send("WHO %s %s%%%s%s" % (mask, filter, fields,
","+label if label else "")) ","+label if label else ""))

View file

@ -1,8 +1,9 @@
import uuid import typing, uuid
from src import IRCBuffer, IRCObject, utils from src import IRCBot, IRCChannel, IRCBuffer, IRCObject, IRCServer, utils
class User(IRCObject.Object): class User(IRCObject.Object):
def __init__(self, nickname, id, server, bot): def __init__(self, nickname: str, id: int, server: "IRCServer.Server",
bot: "IRCBot.Bot"):
self.server = server self.server = server
self.set_nickname(nickname) self.set_nickname(nickname)
self._id = id self._id = id
@ -20,46 +21,51 @@ class User(IRCObject.Object):
self.away = False self.away = False
self.buffer = IRCBuffer.Buffer(bot, server) self.buffer = IRCBuffer.Buffer(bot, server)
def __repr__(self): def __repr__(self) -> str:
return "IRCUser.User(%s|%s)" % (self.server.name, self.name) return "IRCUser.User(%s|%s)" % (self.server.name, self.name)
def __str__(self): def __str__(self) -> str:
return self.nickname return self.nickname
def get_id(self): def get_id(self)-> int:
return (self.identified_account_id_override or return (self.identified_account_id_override or
self.identified_account_id or self._id) self.identified_account_id or self._id)
def get_identified_account(self): def get_identified_account(self) -> str:
return (self.identified_account_override or self.identified_account) return (self.identified_account_override or self.identified_account)
def set_nickname(self, nickname): def set_nickname(self, nickname: str):
self.nickname = nickname self.nickname = nickname
self.nickname_lower = utils.irc.lower(self.server, nickname) self.nickname_lower = utils.irc.lower(self.server.case_mapping,
nickname)
self.name = self.nickname_lower self.name = self.nickname_lower
def join_channel(self, channel): def join_channel(self, channel: "IRCChannel.Channel"):
self.channels.add(channel) self.channels.add(channel)
def part_channel(self, channel): def part_channel(self, channel: "IRCChannel.Channel"):
self.channels.remove(channel) self.channels.remove(channel)
def set_setting(self, setting, value):
def set_setting(self, setting: str, value: typing.Any):
self.bot.database.user_settings.set(self.get_id(), setting, value) self.bot.database.user_settings.set(self.get_id(), setting, value)
def get_setting(self, setting, default=None): def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.bot.database.user_settings.get(self.get_id(), setting, return self.bot.database.user_settings.get(self.get_id(), setting,
default) default)
def find_settings(self, pattern, default=[]): def find_settings(self, pattern: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find(self.get_id(), pattern, return self.bot.database.user_settings.find(self.get_id(), pattern,
default) default)
def find_settings_prefix(self, prefix, default=[]): def find_settings_prefix(self, prefix: str, default: typing.Any=[]
) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find_prefix(self.get_id(), return self.bot.database.user_settings.find_prefix(self.get_id(),
prefix, default) prefix, default)
def del_setting(self, setting): def del_setting(self, setting):
self.bot.database.user_settings.delete(self.get_id(), setting) self.bot.database.user_settings.delete(self.get_id(), setting)
def get_channel_settings_per_setting(self, setting, default=[]): def get_channel_settings_per_setting(self, setting: str,
default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_by_setting( return self.bot.database.user_channel_settings.find_by_setting(
self.get_id(), setting, default) self.get_id(), setting, default)
def send_message(self, message, prefix=None, tags={}): def send_message(self, message: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.nickname, message, prefix=prefix, self.server.send_message(self.nickname, message, prefix=prefix,
tags=tags) tags=tags)
def send_notice(self, text, prefix=None, tags={}): def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags) self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags)
def send_ctcp_response(self, command, args): def send_ctcp_response(self, command: str, args: str):
self.send_notice("\x01%s %s\x01" % (command, args)) self.send_notice("\x01%s %s\x01" % (command, args))

View file

@ -1,4 +1,4 @@
import logging, logging.handlers, os, sys, time import logging, logging.handlers, os, sys, time, typing
LEVELS = { LEVELS = {
"trace": logging.DEBUG-1, "trace": logging.DEBUG-1,
@ -23,7 +23,7 @@ class BitBotFormatter(logging.Formatter):
return s return s
class Log(object): class Log(object):
def __init__(self, level, location): def __init__(self, level: str, location: str):
logging.addLevelName(LEVELS["trace"], "TRACE") logging.addLevelName(LEVELS["trace"], "TRACE")
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@ -49,17 +49,17 @@ class Log(object):
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
def trace(self, message, params, **kwargs): def trace(self, message: str, params: typing.List, **kwargs):
self._log(message, params, LEVELS["trace"], kwargs) self._log(message, params, LEVELS["trace"], kwargs)
def debug(self, message, params, **kwargs): def debug(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.DEBUG, kwargs) self._log(message, params, logging.DEBUG, kwargs)
def info(self, message, params, **kwargs): def info(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.INFO, kwargs) self._log(message, params, logging.INFO, kwargs)
def warn(self, message, params, **kwargs): def warn(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.WARN, kwargs) self._log(message, params, logging.WARN, kwargs)
def error(self, message, params, **kwargs): def error(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.ERROR, kwargs) self._log(message, params, logging.ERROR, kwargs)
def critical(self, message, params, **kwargs): def critical(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.CRITICAL, kwargs) self._log(message, params, logging.CRITICAL, kwargs)
def _log(self, message, params, level, kwargs): def _log(self, message: str, params: typing.List, level: int, kwargs: dict):
self.logger.log(level, message, *params, **kwargs) self.logger.log(level, message, *params, **kwargs)

View file

@ -1,8 +1,5 @@
import gc, glob, imp, io, inspect, os, sys, uuid import gc, glob, imp, io, inspect, os, sys, typing, uuid
from . import utils from src import Config, EventManager, Exports, IRCBot, Logging, Timers, utils
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
class ModuleException(Exception): class ModuleException(Exception):
pass pass
@ -22,7 +19,11 @@ class ModuleNotLoadedWarning(ModuleWarning):
pass pass
class BaseModule(object): class BaseModule(object):
def __init__(self, bot, events, exports, timers): def __init__(self,
bot: "IRCBot.Bot",
events: EventManager.EventHook,
exports: Exports.Exports,
timers: Timers.Timers):
self.bot = bot self.bot = bot
self.events = events self.events = events
self.exports = exports self.exports = exports
@ -32,7 +33,13 @@ class BaseModule(object):
pass pass
class ModuleManager(object): class ModuleManager(object):
def __init__(self, events, exports, timers, config, log, directory): def __init__(self,
events: EventManager.EventHook,
exports: Exports.Exports,
timers: Timers.Timers,
config: Config.Config,
log: Logging.Log,
directory: str):
self.events = events self.events = events
self.exports = exports self.exports = exports
self.config = config self.config = config
@ -43,23 +50,24 @@ class ModuleManager(object):
self.modules = {} self.modules = {}
self.waiting_requirement = {} self.waiting_requirement = {}
def list_modules(self): def list_modules(self) -> typing.List[str]:
return sorted(glob.glob(os.path.join(self.directory, "*.py"))) return sorted(glob.glob(os.path.join(self.directory, "*.py")))
def _module_name(self, path): def _module_name(self, path: str) -> str:
return os.path.basename(path).rsplit(".py", 1)[0].lower() return os.path.basename(path).rsplit(".py", 1)[0].lower()
def _module_path(self, name): def _module_path(self, name: str) -> str:
return os.path.join(self.directory, "%s.py" % name) return os.path.join(self.directory, "%s.py" % name)
def _import_name(self, name): def _import_name(self, name: str) -> str:
return "bitbot_%s" % name return "bitbot_%s" % name
def _get_magic(self, obj, magic, default): def _get_magic(self, obj: typing.Any, magic: str, default: typing.Any
) -> typing.Any:
return getattr(obj, magic) if hasattr(obj, magic) else default return getattr(obj, magic) if hasattr(obj, magic) else default
def _load_module(self, bot, name): def _load_module(self, bot: "IRCBot.Bot", name: str):
path = self._module_path(name) path = self._module_path(name)
for hashflag, value in utils.get_hashflags(path): for hashflag, value in utils.parse.hashflags(path):
if hashflag == "ignore": if hashflag == "ignore":
# nope, ignore this module. # nope, ignore this module.
raise ModuleNotLoadedWarning("module ignored") raise ModuleNotLoadedWarning("module ignored")
@ -97,10 +105,12 @@ class ModuleManager(object):
module_object._name = name.title() module_object._name = name.title()
for attribute_name in dir(module_object): for attribute_name in dir(module_object):
attribute = getattr(module_object, attribute_name) attribute = getattr(module_object, attribute_name)
for hook in self._get_magic(attribute, BITBOT_HOOKS_MAGIC, []): for hook in self._get_magic(attribute,
utils.consts.BITBOT_HOOKS_MAGIC, []):
context_events.on(hook["event"]).hook(attribute, context_events.on(hook["event"]).hook(attribute,
**hook["kwargs"]) **hook["kwargs"])
for export in self._get_magic(module_object, BITBOT_EXPORTS_MAGIC, []): for export in self._get_magic(module_object,
utils.consts.BITBOT_EXPORTS_MAGIC, []):
context_exports.add(export["setting"], export["value"]) context_exports.add(export["setting"], export["value"])
module_object._context = context module_object._context = context
@ -111,7 +121,7 @@ class ModuleManager(object):
"attempted to be used twice") "attempted to be used twice")
return module_object return module_object
def load_module(self, bot, name): def load_module(self, bot: "IRCBot.Bot", name: str):
try: try:
module = self._load_module(bot, name) module = self._load_module(bot, name)
except ModuleWarning as warning: except ModuleWarning as warning:
@ -128,7 +138,8 @@ class ModuleManager(object):
self.load_module(bot, requirement_name) self.load_module(bot, requirement_name)
self.log.info("Module '%s' loaded", [name]) self.log.info("Module '%s' loaded", [name])
def load_modules(self, bot, whitelist=[], blacklist=[]): def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
blacklist: typing.List[str]=[]):
for path in self.list_modules(): for path in self.list_modules():
name = self._module_name(path) name = self._module_name(path)
if name in whitelist or (not whitelist and not name in blacklist): if name in whitelist or (not whitelist and not name in blacklist):
@ -137,7 +148,7 @@ class ModuleManager(object):
except ModuleWarning: except ModuleWarning:
pass pass
def unload_module(self, name): def unload_module(self, name: str):
if not name in self.modules: if not name in self.modules:
raise ModuleNotFoundException() raise ModuleNotFoundException()
module = self.modules[name] module = self.modules[name]

View file

@ -1,7 +1,9 @@
import socket, typing
class Socket(object): class Socket(object):
def __init__(self, socket, on_read, encoding="utf8"): def __init__(self, socket: socket.socket,
on_read: typing.Callable[["Socket", str], None],
encoding: str="utf8"):
self.socket = socket self.socket = socket
self._on_read = on_read self._on_read = on_read
self.encoding = encoding self.encoding = encoding
@ -12,18 +14,18 @@ class Socket(object):
self.length = None self.length = None
self.connected = True self.connected = True
def fileno(self): def fileno(self) -> int:
return self.socket.fileno() return self.socket.fileno()
def disconnect(self): def disconnect(self):
self.connected = False self.connected = False
def _decode(self, s): def _decode(self, s: bytes) -> str:
return s.decode(self.encoding) if self.encoding else s return s.decode(self.encoding)
def _encode(self, s): def _encode(self, s: str) -> bytes:
return s.encode(self.encoding) if self.encoding else s return s.encode(self.encoding)
def read(self): def read(self) -> typing.Optional[typing.List[str]]:
data = self.socket.recv(1024) data = self.socket.recv(1024)
if not data: if not data:
return None return None
@ -35,17 +37,17 @@ class Socket(object):
if data_split[-1]: if data_split[-1]:
self._read_buffer = data_split.pop(-1) self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split] return [self._decode(data) for data in data_split]
return [data.decode(self.encoding)] return [self._decode(data)]
def parse_data(self, data): def parse_data(self, data: str):
self._on_read(self, data) self._on_read(self, data)
def send(self, data): def send(self, data: str):
self._write_buffer += self._encode(data) self._write_buffer += self._encode(data)
def _send(self): def _send(self):
self._write_buffer = self._write_buffer[self.socket.send( self._write_buffer = self._write_buffer[self.socket.send(
self._write_buffer):] self._write_buffer):]
def waiting_send(self): def waiting_send(self) -> bool:
return bool(len(self._write_buffer)) return bool(len(self._write_buffer))

View file

@ -1,7 +1,9 @@
import time, uuid import time, typing, uuid
from src import Database, EventManager, Logging
class Timer(object): class Timer(object):
def __init__(self, id, context, name, delay, next_due, kwargs): def __init__(self, id: int, context: str, name: str, delay: float,
next_due: float, kwargs: dict):
self.id = id self.id = id
self.context = context self.context = context
self.name = name self.name = name
@ -15,9 +17,9 @@ class Timer(object):
def set_next_due(self): def set_next_due(self):
self.next_due = time.time()+self.delay self.next_due = time.time()+self.delay
def due(self): def due(self) -> bool:
return self.time_left() <= 0 return self.time_left() <= 0
def time_left(self): def time_left(self) -> float:
return self.next_due-time.time() return self.next_due-time.time()
def redo(self): def redo(self):
@ -25,42 +27,33 @@ class Timer(object):
self.set_next_due() self.set_next_due()
def finish(self): def finish(self):
self._done = True self._done = True
def done(self): def done(self) -> bool:
return self._done return self._done
class TimersContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def add(self, name, delay, next_due=None, **kwargs):
self._parent._add(self.context, name, delay, next_due, None, False,
kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs):
self._parent._add(None, name, delay, next_due, None, True,
kwargs)
class Timers(object): class Timers(object):
def __init__(self, database, events, log): def __init__(self, database: Database.Database,
events: EventManager.EventHook,
log: Logging.Log):
self.database = database self.database = database
self.events = events self.events = events
self.log = log self.log = log
self.timers = [] self.timers = []
self.context_timers = {} self.context_timers = {}
def new_context(self, context): def new_context(self, context: str) -> "TimersContext":
return TimersContext(self, context) return TimersContext(self, context)
def setup(self, timers): def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
for name, timer in timers: for name, timer in timers:
id = name.split("timer-", 1)[1] id = name.split("timer-", 1)[1]
self._add(timer["name"], None, timer["delay"], timer[ self._add(timer["name"], None, timer["delay"], timer[
"next-due"], id, False, timer["kwargs"]) "next-due"], id, False, timer["kwargs"])
def _persist(self, timer): def _persist(self, timer: Timer):
self.database.bot_settings.set("timer-%s" % timer.id, { self.database.bot_settings.set("timer-%s" % timer.id, {
"name": timer.name, "delay": timer.delay, "name": timer.name, "delay": timer.delay,
"next-due": timer.next_due, "kwargs": timer.kwargs}) "next-due": timer.next_due, "kwargs": timer.kwargs})
def _remove(self, timer): def _remove(self, timer: Timer):
if timer.context: if timer.context:
self.context_timers[timer.context].remove(timer) self.context_timers[timer.context].remove(timer)
if not self.context_timers[timer.context]: if not self.context_timers[timer.context]:
@ -69,11 +62,13 @@ class Timers(object):
self.timers.remove(timer) self.timers.remove(timer)
self.database.bot_settings.delete("timer-%s" % timer.id) self.database.bot_settings.delete("timer-%s" % timer.id)
def add(self, name, delay, next_due=None, **kwargs): def add(self, name: str, delay: float, next_due: float=None, **kwargs):
self._add(None, name, delay, next_due, None, False, kwargs) self._add(None, name, delay, next_due, None, False, kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs): def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._add(None, name, delay, next_due, None, True, kwargs) self._add(None, name, delay, next_due, None, True, kwargs)
def _add(self, context, name, delay, next_due, id, persist, kwargs): def _add(self, context: str, name: str, delay: float, next_due: float,
id: str, persist: bool, kwargs: dict):
id = id or uuid.uuid4().hex id = id or uuid.uuid4().hex
timer = Timer(id, context, name, delay, next_due, kwargs) timer = Timer(id, context, name, delay, next_due, kwargs)
if persist: if persist:
@ -86,13 +81,13 @@ class Timers(object):
else: else:
self.timers.append(timer) self.timers.append(timer)
def next(self): def next(self) -> float:
times = filter(None, [timer.time_left() for timer in self.get_timers()]) times = filter(None, [timer.time_left() for timer in self.get_timers()])
if not times: if not times:
return None return None
return max(min(times), 0) return max(min(times), 0)
def get_timers(self): def get_timers(self) -> typing.List[Timer]:
return self.timers + sum(self.context_timers.values(), []) return self.timers + sum(self.context_timers.values(), [])
def call(self): def call(self):
@ -104,6 +99,19 @@ class Timers(object):
if timer.done(): if timer.done():
self._remove(timer) self._remove(timer)
def purge_context(self, context): def purge_context(self, context: str):
if context in self.context_timers: if context in self.context_timers:
del self.context_timers[context] del self.context_timers[context]
class TimersContext(object):
def __init__(self, parent: Timers, context: str):
self._parent = parent
self.context = context
def add(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._parent._add(self.context, name, delay, next_due, None, False,
kwargs)
def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._parent._add(None, name, delay, next_due, None, True,
kwargs)

View file

@ -1,6 +1,5 @@
import decimal, io, re import decimal, io, re, typing
from src import ModuleManager from src.utils import consts, irc, http, parse
from . import irc, http
TIME_SECOND = 1 TIME_SECOND = 1
TIME_MINUTE = TIME_SECOND*60 TIME_MINUTE = TIME_SECOND*60
@ -8,7 +7,7 @@ TIME_HOUR = TIME_MINUTE*60
TIME_DAY = TIME_HOUR*24 TIME_DAY = TIME_HOUR*24
TIME_WEEK = TIME_DAY*7 TIME_WEEK = TIME_DAY*7
def time_unit(seconds): def time_unit(seconds: int) -> typing.Tuple[int, str]:
since = None since = None
unit = None unit = None
if seconds >= TIME_WEEK: if seconds >= TIME_WEEK:
@ -29,7 +28,7 @@ def time_unit(seconds):
since = int(since) since = int(since)
if since > 1: if since > 1:
unit = "%ss" % unit # pluralise the unit unit = "%ss" % unit # pluralise the unit
return [since, unit] return (since, unit)
REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I) REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I)
@ -38,7 +37,7 @@ SECONDS_HOURS = SECONDS_MINUTES*60
SECONDS_DAYS = SECONDS_HOURS*24 SECONDS_DAYS = SECONDS_HOURS*24
SECONDS_WEEKS = SECONDS_DAYS*7 SECONDS_WEEKS = SECONDS_DAYS*7
def from_pretty_time(pretty_time): def from_pretty_time(pretty_time: str) -> typing.Optional[int]:
seconds = 0 seconds = 0
for match in re.findall(REGEX_PRETTYTIME, pretty_time): for match in re.findall(REGEX_PRETTYTIME, pretty_time):
number, unit = int(match[:-1]), match[-1].lower() number, unit = int(match[:-1]), match[-1].lower()
@ -54,12 +53,14 @@ def from_pretty_time(pretty_time):
if seconds > 0: if seconds > 0:
return seconds return seconds
UNIT_MINIMUM = 6
UNIT_SECOND = 5 UNIT_SECOND = 5
UNIT_MINUTE = 4 UNIT_MINUTE = 4
UNIT_HOUR = 3 UNIT_HOUR = 3
UNIT_DAY = 2 UNIT_DAY = 2
UNIT_WEEK = 1 UNIT_WEEK = 1
def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6): def to_pretty_time(total_seconds: int, minimum_unit: int=UNIT_SECOND,
max_units: int=UNIT_MINIMUM) -> str:
minutes, seconds = divmod(total_seconds, 60) minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60) hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24) days, hours = divmod(hours, 24)
@ -84,7 +85,7 @@ def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
units += 1 units += 1
return out return out
def parse_number(s): def parse_number(s: str) -> str:
try: try:
decimal.Decimal(s) decimal.Decimal(s)
return s return s
@ -110,28 +111,18 @@ def parse_number(s):
IS_TRUE = ["true", "yes", "on", "y"] IS_TRUE = ["true", "yes", "on", "y"]
IS_FALSE = ["false", "no", "off", "n"] IS_FALSE = ["false", "no", "off", "n"]
def bool_or_none(s): def bool_or_none(s: str) -> typing.Optional[bool]:
s = s.lower() s = s.lower()
if s in IS_TRUE: if s in IS_TRUE:
return True return True
elif s in IS_FALSE: elif s in IS_FALSE:
return False return False
def int_or_none(s): def int_or_none(s: str) -> typing.Optional[int]:
stripped_s = s.lstrip("0") stripped_s = s.lstrip("0")
if stripped_s.isdigit(): if stripped_s.isdigit():
return int(stripped_s) return int(stripped_s)
def get_closest_setting(event, setting, default=None): def prevent_highlight(nickname: str) -> str:
server = event["server"]
if "channel" in event:
closest = event["channel"]
elif "target" in event and "is_channel" in event and event["is_channel"]:
closest = event["target"]
else:
closest = event["user"]
return closest.get_setting(setting, server.get_setting(setting, default))
def prevent_highlight(nickname):
return nickname[0]+"\u200c"+nickname[1:] return nickname[0]+"\u200c"+nickname[1:]
class EventError(Exception): class EventError(Exception):
@ -139,80 +130,34 @@ class EventError(Exception):
class EventsResultsError(EventError): class EventsResultsError(EventError):
def __init__(self): def __init__(self):
EventError.__init__(self, "Failed to load results") EventError.__init__(self, "Failed to load results")
class EventsNotEnoughArgsError(EventError):
def __init__(self, n):
EventError.__init__(self, "Not enough arguments (minimum %d)" % n)
class EventsUsageError(EventError):
def __init__(self, usage):
EventError.__init__(self, "Not enough arguments, usage: %s" % usage)
def _set_get_append(obj, setting, item): def _set_get_append(obj: typing.Any, setting: str, item: typing.Any):
if not hasattr(obj, setting): if not hasattr(obj, setting):
setattr(obj, setting, []) setattr(obj, setting, [])
getattr(obj, setting).append(item) getattr(obj, setting).append(item)
def hook(event, **kwargs): def hook(event: str, **kwargs):
def _hook_func(func): def _hook_func(func):
_set_get_append(func, ModuleManager.BITBOT_HOOKS_MAGIC, _set_get_append(func, consts.BITBOT_HOOKS_MAGIC,
{"event": event, "kwargs": kwargs}) {"event": event, "kwargs": kwargs})
return func return func
return _hook_func return _hook_func
def export(setting, value): def export(setting: str, value: typing.Any):
def _export_func(module): def _export_func(module):
_set_get_append(module, ModuleManager.BITBOT_EXPORTS_MAGIC, _set_get_append(module, consts.BITBOT_EXPORTS_MAGIC,
{"setting": setting, "value": value}) {"setting": setting, "value": value})
return module return module
return _export_func return _export_func
COMMENT_TYPES = ["#", "//"] TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
def get_hashflags(filename): def top_10(items: typing.List[typing.Any],
hashflags = {} convert_key: TOP_10_CALLABLE=lambda x: x,
with io.open(filename, mode="r", encoding="utf8") as f: value_format: TOP_10_CALLABLE=lambda x: x):
for line in f:
line = line.strip("\n")
found = False
for comment_type in COMMENT_TYPES:
if line.startswith(comment_type):
line = line.replace(comment_type, "", 1).lstrip()
found = True
break
if not found:
break
elif line.startswith("--"):
hashflag, sep, value = line[2:].partition(" ")
hashflags[hashflag] = value if sep else None
return hashflags.items()
class Docstring(object):
def __init__(self, description, items, var_items):
self.description = description
self.items = items
self.var_items = var_items
def parse_docstring(s):
description = ""
last_item = None
items = {}
var_items = {}
if s:
for line in s.split("\n"):
line = line.strip()
if line:
if line[0] == ":":
key, _, value = line[1:].partition(": ")
last_item = key
if key in var_items:
var_items[key].append(value)
elif key in items:
var_items[key] = [items.pop(key), value]
else:
items[key] = value
else:
if last_item:
items[last_item] += " %s" % line
else:
if description:
description += " "
description += line
return Docstring(description, items, var_items)
def top_10(items, convert_key=lambda x: x, value_format=lambda x: x):
top_10 = sorted(items.keys()) top_10 = sorted(items.keys())
top_10 = sorted(top_10, key=items.get, reverse=True)[:10] top_10 = sorted(top_10, key=items.get, reverse=True)[:10]

2
src/utils/consts.py Normal file
View file

@ -0,0 +1,2 @@
BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
BITBOT_EXPORTS_MAGIC = "__bitbot_exports"

View file

@ -1,4 +1,4 @@
import re, signal, traceback, urllib.error, urllib.parse import re, signal, traceback, typing, urllib.error, urllib.parse
import json as _json import json as _json
import bs4, requests import bs4, requests
@ -18,9 +18,10 @@ class HTTPParsingException(HTTPException):
def throw_timeout(): def throw_timeout():
raise HTTPTimeoutException() raise HTTPTimeoutException()
def get_url(url, method="GET", get_params={}, post_data=None, headers={}, def get_url(url: str, method: str="GET", get_params: dict={},
json_data=None, code=False, json=False, soup=False, parser="lxml", post_data: typing.Any=None, headers: dict={},
fallback_encoding="utf8"): json_data: typing.Any=None, code: bool=False, json: bool=False,
soup: bool=False, parser: str="lxml", fallback_encoding: str="utf8"):
if not urllib.parse.urlparse(url).scheme: if not urllib.parse.urlparse(url).scheme:
url = "http://%s" % url url = "http://%s" % url
@ -66,6 +67,6 @@ def get_url(url, method="GET", get_params={}, post_data=None, headers={},
else: else:
return data return data
def strip_html(s): def strip_html(s: str) -> str:
return bs4.BeautifulSoup(s, "lxml").get_text() return bs4.BeautifulSoup(s, "lxml").get_text()

View file

@ -1,4 +1,4 @@
import string, re import string, re, typing
ASCII_UPPER = string.ascii_uppercase ASCII_UPPER = string.ascii_uppercase
ASCII_LOWER = string.ascii_lowercase ASCII_LOWER = string.ascii_lowercase
@ -7,32 +7,36 @@ STRICT_RFC1459_LOWER = ASCII_LOWER+r'|{}'
RFC1459_UPPER = STRICT_RFC1459_UPPER+"^" RFC1459_UPPER = STRICT_RFC1459_UPPER+"^"
RFC1459_LOWER = STRICT_RFC1459_LOWER+"~" RFC1459_LOWER = STRICT_RFC1459_LOWER+"~"
def remove_colon(s): def remove_colon(s: str) -> str:
if s.startswith(":"): if s.startswith(":"):
s = s[1:] s = s[1:]
return s return s
MULTI_REPLACE_ITERABLE = typing.Iterable[str]
# case mapping lowercase/uppcase logic # case mapping lowercase/uppcase logic
def _multi_replace(s, chars1, chars2): def _multi_replace(s: str,
chars1: typing.Iterable[str],
chars2: typing.Iterable[str]) -> str:
for char1, char2 in zip(chars1, chars2): for char1, char2 in zip(chars1, chars2):
s = s.replace(char1, char2) s = s.replace(char1, char2)
return s return s
def lower(server, s): def lower(case_mapping: str, s: str) -> str:
if server.case_mapping == "ascii": if case_mapping == "ascii":
return _multi_replace(s, ASCII_UPPER, ASCII_LOWER) return _multi_replace(s, ASCII_UPPER, ASCII_LOWER)
elif server.case_mapping == "rfc1459": elif case_mapping == "rfc1459":
return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER) return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER)
elif server.case_mapping == "strict-rfc1459": elif case_mapping == "strict-rfc1459":
return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER) return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER)
else: else:
raise ValueError("unknown casemapping '%s'" % server.case_mapping) raise ValueError("unknown casemapping '%s'" % case_mapping)
# compare a string while respecting case mapping # compare a string while respecting case mapping
def equals(server, s1, s2): def equals(case_mapping: str, s1: str, s2: str) -> bool:
return lower(server, s1) == lower(server, s2) return lower(case_mapping, s1) == lower(case_mapping, s2)
class IRCHostmask(object): class IRCHostmask(object):
def __init__(self, nickname, username, hostname, hostmask): def __init__(self, nickname: str, username: str, hostname: str,
hostmask: str):
self.nickname = nickname self.nickname = nickname
self.username = username self.username = username
self.hostname = hostname self.hostname = hostname
@ -42,24 +46,24 @@ class IRCHostmask(object):
def __str__(self): def __str__(self):
return self.hostmask return self.hostmask
def seperate_hostmask(hostmask): def seperate_hostmask(hostmask: str) -> IRCHostmask:
hostmask = remove_colon(hostmask) hostmask = remove_colon(hostmask)
nickname, _, username = hostmask.partition("!") nickname, _, username = hostmask.partition("!")
username, _, hostname = username.partition("@") username, _, hostname = username.partition("@")
return IRCHostmask(nickname, username, hostname, hostmask) return IRCHostmask(nickname, username, hostname, hostmask)
class IRCLine(object): class IRCLine(object):
def __init__(self, tags, prefix, command, args, arbitrary, last, server): def __init__(self, tags: dict, prefix: str, command: str,
args: typing.List[str], arbitrary: typing.Optional[str],
last: str):
self.tags = tags self.tags = tags
self.prefix = prefix self.prefix = prefix
self.command = command self.command = command
self.args = args self.args = args
self.arbitrary = arbitrary self.arbitrary = arbitrary
self.last = last self.last = last
self.server = server
def parse_line(server, line): def parse_line(line: str) -> IRCLine:
tags = {} tags = {}
prefix = None prefix = None
command = None command = None
@ -81,7 +85,7 @@ def parse_line(server, line):
args = line.split(" ") args = line.split(" ")
last = arbitrary or args[-1] last = arbitrary or args[-1]
return IRCLine(tags, prefix, command, args, arbitrary, last, server) return IRCLine(tags, prefix, command, args, arbitrary, last)
COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3 COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3
COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7 COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7
@ -94,20 +98,20 @@ FONT_BOLD, FONT_ITALIC, FONT_UNDERLINE, FONT_INVERT = ("\x02", "\x1D",
FONT_COLOR, FONT_RESET = "\x03", "\x0F" FONT_COLOR, FONT_RESET = "\x03", "\x0F"
REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR) REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR)
def color(s, foreground, background=None): def color(s: str, foreground: str, background: str=None) -> str:
foreground = str(foreground).zfill(2) foreground = str(foreground).zfill(2)
if background: if background:
background = str(background).zfill(2) background = str(background).zfill(2)
return "%s%s%s%s%s" % (FONT_COLOR, foreground, return "%s%s%s%s%s" % (FONT_COLOR, foreground,
"" if not background else ",%s" % background, s, FONT_COLOR) "" if not background else ",%s" % background, s, FONT_COLOR)
def bold(s): def bold(s: str) -> str:
return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD) return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD)
def underline(s): def underline(s: str) -> str:
return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE) return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE)
def strip_font(s): def strip_font(s: str) -> str:
s = s.replace(FONT_BOLD, "") s = s.replace(FONT_BOLD, "")
s = s.replace(FONT_ITALIC, "") s = s.replace(FONT_ITALIC, "")
s = REGEX_COLOR.sub("", s) s = REGEX_COLOR.sub("", s)

57
src/utils/parse.py Normal file
View file

@ -0,0 +1,57 @@
import io, typing
COMMENT_TYPES = ["#", "//"]
def hashflags(filename: str) -> typing.List[typing.Tuple[str, str]]:
hashflags = {}
with io.open(filename, mode="r", encoding="utf8") as f:
for line in f:
line = line.strip("\n")
found = False
for comment_type in COMMENT_TYPES:
if line.startswith(comment_type):
line = line.replace(comment_type, "", 1).lstrip()
found = True
break
if not found:
break
elif line.startswith("--"):
hashflag, sep, value = line[2:].partition(" ")
hashflags[hashflag] = value if sep else None
return list(hashflags.items())
class Docstring(object):
def __init__(self, description: str, items: dict, var_items: dict):
self.description = description
self.items = items
self.var_items = var_items
def docstring(s: str) -> Docstring:
description = ""
last_item = None
items = {}
var_items = {}
if s:
for line in s.split("\n"):
line = line.strip()
if line:
if line[0] == ":":
key, _, value = line[1:].partition(": ")
last_item = key
if key in var_items:
var_items[key].append(value)
elif key in items:
var_items[key] = [items.pop(key), value]
else:
items[key] = value
else:
if last_item:
items[last_item] += " %s" % line
else:
if description:
description += " "
description += line
return Docstring(description, items, var_items)