Fix/refactor issues brought up by type hint linting

This commit is contained in:
jesopo 2018-10-30 17:49:35 +00:00
parent d0e3574227
commit b543e31cd2
13 changed files with 73 additions and 59 deletions

View file

@ -8,14 +8,14 @@ class Cache(object):
def cache(self, item: typing.Any) -> str: def cache(self, item: typing.Any) -> str:
return self._cache(item, None) return self._cache(item, None)
def temporary_cache(self, item: typing.Any, timeout: float)-> str: def temporary_cache(self, item: typing.Any, timeout: float)-> str:
return self._cache(item, timeout) return self._cache(item, time.monotonic()+timeout)
def _cache(self, item: typing.Any, timeout: float) -> str: def _cache(self, item: typing.Any, timeout: typing.Optional[float]) -> str:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
self._items[id] = [item, time.monotonic()+timeout] self._items[id] = [item, timeout]
self._item_to_id[item] = id self._item_to_id[item] = id
return id return id
def next_expiration(self) -> float: def next_expiration(self) -> typing.Optional[float]:
expirations = [self._items[id][1] for id in self._items] expirations = [self._items[id][1] for id in self._items]
expirations = list(filter(None, expirations)) expirations = list(filter(None, expirations))
if not expirations: if not expirations:

View file

@ -32,7 +32,7 @@ class EventCallback(object):
self.function = function self.function = function
self.priority = priority self.priority = priority
self.kwargs = kwargs self.kwargs = kwargs
self.docstring = utils.parse.docstring(function.__doc__) self.docstring = utils.parse.docstring(function.__doc__ or "")
def call(self, event: Event) -> typing.Any: def call(self, event: Event) -> typing.Any:
return self.function(event) return self.function(event)
@ -49,6 +49,7 @@ class EventHook(object):
self.parent = parent self.parent = parent
self._children = {} self._children = {}
self._hooks = [] self._hooks = []
self._replayed = False
self._stored_events = [] self._stored_events = []
self._context_hooks = {} self._context_hooks = {}
@ -57,10 +58,14 @@ class EventHook(object):
def _get_path(self) -> str: def _get_path(self) -> str:
path = [] path = []
parent = self parent = self # type: typing.Optional[EventHook]
while not parent == None and not parent.name == None: while not parent == None:
path.append(parent.name) cast_parent = typing.cast(EventHook, parent)
parent = parent.parent if cast_parent.name == None:
break
path.append(typing.cast(str, cast_parent.name))
parent = cast_parent.parent
return DEFAULT_EVENT_DELIMITER.join(path[::-1]) return DEFAULT_EVENT_DELIMITER.join(path[::-1])
def new_context(self, context: str) -> "EventHookContext": def new_context(self, context: str) -> "EventHookContext":
@ -72,8 +77,8 @@ class EventHook(object):
def _context_hook(self, context: str, function: CALLBACK_TYPE, def _context_hook(self, context: str, function: CALLBACK_TYPE,
priority: int, replay: bool, kwargs: dict) -> EventCallback: priority: int, replay: bool, kwargs: dict) -> EventCallback:
return self._hook(function, context, priority, replay, kwargs) return self._hook(function, context, priority, replay, kwargs)
def _hook(self, function: CALLBACK_TYPE, context: str, priority: int, def _hook(self, function: CALLBACK_TYPE, context: typing.Optional[str],
replay: bool, kwargs: dict) -> EventCallback: priority: int, replay: bool, kwargs: dict) -> EventCallback:
callback = EventCallback(function, priority, kwargs) callback = EventCallback(function, priority, kwargs)
if context == None: if context == None:
@ -83,10 +88,10 @@ class EventHook(object):
self._context_hooks[context] = [] self._context_hooks[context] = []
self._context_hooks[context].append(callback) self._context_hooks[context].append(callback)
if replay and not self._stored_events == None: if replay and not self._replayed:
for kwargs in self._stored_events: for kwargs in self._stored_events:
self._call(kwargs, True, None) self._call(kwargs, True, None)
self._stored_events = None self._replayed = True
return callback return callback
def unhook(self, callback: "EventHook"): def unhook(self, callback: "EventHook"):
@ -102,58 +107,65 @@ class EventHook(object):
for context in empty: for context in empty:
del self._context_hooks[context] del self._context_hooks[context]
def _make_multiple_hook(self, source: "EventHook", context: str, def _make_multiple_hook(self, source: "EventHook",
events: typing.List[str]) -> "MultipleEventHook": context: typing.Optional[str],
events: typing.Iterable[str]) -> "MultipleEventHook":
multiple_event_hook = MultipleEventHook() multiple_event_hook = MultipleEventHook()
for event in events: for event in events:
event_hook = source.get_child(event) event_hook = source.get_child(event)
if not context == None: if not context == None:
event_hook = event_hook.new_context(context) context_hook = event_hook.new_context(typing.cast(str, context))
multiple_event_hook._add(event_hook) multiple_event_hook._add(typing.cast(EventHook, context_hook))
else:
multiple_event_hook._add(event_hook)
return multiple_event_hook return multiple_event_hook
def on(self, subevent: str, *extra_subevents, def on(self, subevent: str, *extra_subevents: str,
delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook": delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, None, delimiter) return self._on(subevent, extra_subevents, None, delimiter)
def _context_on(self, context: str, subevent: str, def _context_on(self, context: str, subevent: str,
extra_subevents: typing.List[str], extra_subevents: typing.Tuple[str, ...],
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook": delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, context, delimiter) return self._on(subevent, extra_subevents, context, delimiter)
def _on(self, subevent: str, extra_subevents: typing.List[str], def _on(self, subevent: str, extra_subevents: typing.Tuple[str, ...],
context: str, delimiter: str) -> "EventHook": context: typing.Optional[str], delimiter: str) -> "EventHook":
if delimiter in subevent: if delimiter in subevent:
event_chain = subevent.split(delimiter) event_chain = subevent.split(delimiter)
event_obj = self event_obj = self
for event_name in event_chain: for event_name in event_chain:
if DEFAULT_MULTI_DELIMITER in event_name: if DEFAULT_MULTI_DELIMITER in event_name:
return self._make_multiple_hook(event_obj, context, multiple_hook = self._make_multiple_hook(event_obj, context,
event_name.split(DEFAULT_MULTI_DELIMITER)) event_name.split(DEFAULT_MULTI_DELIMITER))
return typing.cast(EventHook, multiple_hook)
event_obj = event_obj.get_child(event_name) event_obj = event_obj.get_child(event_name)
if not context == None: if not context == None:
return event_obj.new_context(context) context_hook = event_obj.new_context(typing.cast(str, context))
return typing.cast(EventHook, context_hook)
return event_obj return event_obj
if extra_subevents: if extra_subevents:
return self._make_multiple_hook(self, context, multiple_hook = self._make_multiple_hook(self, context,
(subevent,)+extra_subevents) (subevent,)+extra_subevents)
return typing.cast(EventHook, multiple_hook)
child = self.get_child(subevent) child = self.get_child(subevent)
if not context == None: if not context == None:
child = child.new_context(context) context_child = child.new_context(typing.cast(str, context))
child = typing.cast(EventHook, context_child)
return child return child
def call_for_result(self, default=None, **kwargs) -> typing.Any: def call_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_limited(1, **kwargs) or [default])[0] return (self.call_limited(1, **kwargs) or [default])[0]
def assure_call(self, **kwargs): def assure_call(self, **kwargs):
if not self._stored_events == None: if not self._replayed:
self._stored_events.append(kwargs) self._stored_events.append(kwargs)
else: else:
self._call(kwargs, True, None) self._call(kwargs, True, None)
def call(self, **kwargs) -> typing.List[typing.Any]: def call(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None) return self._call(kwargs, True, None)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]: def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None) return self._call(kwargs, True, maximum)
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any: def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_unsafe_limited(1, **kwargs) or [default])[0] return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
@ -163,7 +175,7 @@ class EventHook(object):
) -> typing.List[typing.Any]: ) -> typing.List[typing.Any]:
return self._call(kwargs, False, maximum) return self._call(kwargs, False, maximum)
def _call(self, kwargs: dict, safe: bool, maximum: int def _call(self, kwargs: dict, safe: bool, maximum: typing.Optional[int]
) -> typing.List[typing.Any]: ) -> typing.List[typing.Any]:
event_path = self._get_path() event_path = self._get_path()
self.log.trace("calling event: \"%s\" (params: %s)", self.log.trace("calling event: \"%s\" (params: %s)",
@ -223,10 +235,10 @@ class EventHook(object):
def get_hooks(self) -> typing.List[EventCallback]: def get_hooks(self) -> typing.List[EventCallback]:
return sorted(self._hooks + sum(self._context_hooks.values(), []), return sorted(self._hooks + sum(self._context_hooks.values(), []),
key=lambda e: e.priority) key=lambda e: e.priority)
def get_children(self) -> typing.List["EventHook"]: def get_children(self) -> typing.List[str]:
return list(self._children.keys()) return list(self._children.keys())
def is_empty(self) -> bool: def is_empty(self) -> bool:
return len(self.get_hooks() + self.get_children()) == 0 return (len(self.get_hooks())+len(self.get_children())) == 0
class MultipleEventHook(object): class MultipleEventHook(object):
def __init__(self): def __init__(self):

View file

@ -12,7 +12,8 @@ class Exports(object):
self._add(None, setting, value) self._add(None, setting, value)
def _context_add(self, context: str, setting: str, value: typing.Any): def _context_add(self, context: str, setting: str, value: typing.Any):
self._add(context, setting, value) self._add(context, setting, value)
def _add(self, context: str, setting: str, value: typing.Any): def _add(self, context: typing.Optional[str], setting: str,
value: typing.Any):
if context == None: if context == None:
if not setting in self._exports: if not setting in self._exports:
self._exports[setting] = [] self._exports[setting] = []

View file

@ -36,7 +36,7 @@ class Bot(object):
self.lock.release() self.lock.release()
def add_server(self, server_id: int, connect: bool = True def add_server(self, server_id: int, connect: bool = True
) -> typing.Optional[IRCServer.Server]: ) -> IRCServer.Server:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname, (_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id) username, realname) = self.database.servers.get(server_id)
@ -44,7 +44,7 @@ class Bot(object):
hostname, port, password, ipv4, tls, bindhost, nickname, username, hostname, port, password, ipv4, tls, bindhost, nickname, username,
realname) realname)
if not new_server.get_setting("connect", True): if not new_server.get_setting("connect", True):
return return new_server
self._events.on("new.server").call(server=new_server) self._events.on("new.server").call(server=new_server)
if connect and new_server.get_setting("connect", True): if connect and new_server.get_setting("connect", True):
self.connect(new_server) self.connect(new_server)

View file

@ -1,5 +1,5 @@
import re, typing import re, typing
from src import IRCBot, utils from src import IRCBot, IRCServer, utils
class BufferLine(object): class BufferLine(object):
def __init__(self, sender: str, message: str, action: bool, tags: dict, def __init__(self, sender: str, message: str, action: bool, tags: dict,

View file

@ -9,9 +9,9 @@ class Channel(IRCObject.Object):
self.server = server self.server = server
self.bot = bot self.bot = bot
self.topic = "" self.topic = ""
self.topic_setter_nickname = None self.topic_setter_nickname = None # type: typing.Optional[str]
self.topic_setter_username = None self.topic_setter_username = None # type: typing.Optional[str]
self.topic_setter_hostname = None self.topic_setter_hostname = None # type: typing.Optional[str]
self.topic_time = 0 self.topic_time = 0
self.users = set([]) self.users = set([])
self.modes = {} self.modes = {}

View file

@ -160,7 +160,7 @@ class Server(IRCObject.Object):
self.nickname = nickname self.nickname = nickname
self.nickname_lower = utils.irc.lower(self.case_mapping, nickname) self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
def is_own_nickname(self, nickname: str): def is_own_nickname(self, nickname: str):
return utils.irc.equals(self, nickname, self.nickname) return utils.irc.equals(self.case_mapping, nickname, self.nickname)
def add_own_mode(self, mode: str, arg: str=None): def add_own_mode(self, mode: str, arg: str=None):
self.own_modes[mode] = arg self.own_modes[mode] = arg

View file

@ -29,7 +29,7 @@ class User(IRCObject.Object):
def get_id(self)-> int: def get_id(self)-> int:
return (self.identified_account_id_override or return (self.identified_account_id_override or
self.identified_account_id or self._id) self.identified_account_id or self._id)
def get_identified_account(self) -> str: def get_identified_account(self) -> typing.Optional[str]:
return (self.identified_account_override or self.identified_account) return (self.identified_account_override or self.identified_account)
def set_nickname(self, nickname: str): def set_nickname(self, nickname: str):

View file

@ -33,7 +33,7 @@ class Socket(object):
data = self._read_buffer+data data = self._read_buffer+data
self._read_buffer = b"" self._read_buffer = b""
if not self.delimiter == None: if not self.delimiter == None:
data_split = data.split(delimiter) data_split = data.split(self.delimiter)
if data_split[-1]: if data_split[-1]:
self._read_buffer = data_split.pop(-1) self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split] return [self._decode(data) for data in data_split]

View file

@ -2,8 +2,8 @@ import time, typing, uuid
from src import Database, EventManager, Logging from src import Database, EventManager, Logging
class Timer(object): class Timer(object):
def __init__(self, id: int, context: str, name: str, delay: float, def __init__(self, id: str, context: typing.Optional[str], name: str,
next_due: float, kwargs: dict): delay: float, next_due: typing.Optional[float], kwargs: dict):
self.id = id self.id = id
self.context = context self.context = context
self.name = name self.name = name
@ -46,7 +46,7 @@ class Timers(object):
def setup(self, timers: typing.List[typing.Tuple[str, dict]]): def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
for name, timer in timers: for name, timer in timers:
id = name.split("timer-", 1)[1] id = name.split("timer-", 1)[1]
self._add(timer["name"], None, timer["delay"], timer[ self._add(None, timer["name"], timer["delay"], timer[
"next-due"], id, False, timer["kwargs"]) "next-due"], id, False, timer["kwargs"])
def _persist(self, timer: Timer): def _persist(self, timer: Timer):
@ -67,9 +67,10 @@ class Timers(object):
def add_persistent(self, name: str, delay: float, next_due: float=None, def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs): **kwargs):
self._add(None, name, delay, next_due, None, True, kwargs) self._add(None, name, delay, next_due, None, True, kwargs)
def _add(self, context: str, name: str, delay: float, next_due: float, def _add(self, context: typing.Optional[str], name: str, delay: float,
id: str, persist: bool, kwargs: dict): next_due: typing.Optional[float], id: typing.Optional[str],
id = id or uuid.uuid4().hex persist: bool, kwargs: dict):
id = id or str(uuid.uuid4())
timer = Timer(id, context, name, delay, next_due, kwargs) timer = Timer(id, context, name, delay, next_due, kwargs)
if persist: if persist:
self._persist(timer) self._persist(timer)
@ -81,7 +82,7 @@ class Timers(object):
else: else:
self.timers.append(timer) self.timers.append(timer)
def next(self) -> float: def next(self) -> typing.Optional[float]:
times = filter(None, [timer.time_left() for timer in self.get_timers()]) times = filter(None, [timer.time_left() for timer in self.get_timers()])
if not times: if not times:
return None return None

View file

@ -93,9 +93,9 @@ def parse_number(s: str) -> str:
pass pass
unit = s[-1].lower() unit = s[-1].lower()
number = s[:-1] number_str = s[:-1]
try: try:
number = decimal.Decimal(number) number = decimal.Decimal(number_str)
except: except:
raise ValueError("Invalid format '%s' passed to parse_number" % number) raise ValueError("Invalid format '%s' passed to parse_number" % number)
@ -155,7 +155,7 @@ def export(setting: str, value: typing.Any):
return _export_func return _export_func
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any] TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
def top_10(items: typing.List[typing.Any], def top_10(items: typing.Dict[typing.Any, typing.Any],
convert_key: TOP_10_CALLABLE=lambda x: x, convert_key: TOP_10_CALLABLE=lambda x: x,
value_format: TOP_10_CALLABLE=lambda x: x): value_format: TOP_10_CALLABLE=lambda x: x):
top_10 = sorted(items.keys()) top_10 = sorted(items.keys())

View file

@ -8,7 +8,7 @@ REGEX_HTTP = re.compile("https?://", re.I)
RESPONSE_MAX = (1024*1024)*100 RESPONSE_MAX = (1024*1024)*100
class HTTPException: class HTTPException(Exception):
pass pass
class HTTPTimeoutException(HTTPException): class HTTPTimeoutException(HTTPException):
pass pass
@ -52,7 +52,7 @@ def get_url(url: str, method: str="GET", get_params: dict={},
if soup: if soup:
soup = bs4.BeautifulSoup(response_content, parser) soup = bs4.BeautifulSoup(response_content, parser)
if code: if code:
return response.code, soup return response.status_code, soup
return soup return soup
data = response_content.decode(response.encoding or fallback_encoding) data = response_content.decode(response.encoding or fallback_encoding)

View file

@ -53,7 +53,7 @@ def seperate_hostmask(hostmask: str) -> IRCHostmask:
return IRCHostmask(nickname, username, hostname, hostmask) return IRCHostmask(nickname, username, hostname, hostmask)
class IRCLine(object): class IRCLine(object):
def __init__(self, tags: dict, prefix: str, command: str, def __init__(self, tags: dict, prefix: typing.Optional[str], command: str,
args: typing.List[str], arbitrary: typing.Optional[str], args: typing.List[str], arbitrary: typing.Optional[str],
last: str): last: str):
self.tags = tags self.tags = tags
@ -65,7 +65,7 @@ class IRCLine(object):
def parse_line(line: str) -> IRCLine: def parse_line(line: str) -> IRCLine:
tags = {} tags = {}
prefix = None prefix = typing.Optional[IRCHostmask]
command = None command = None
if line[0] == "@": if line[0] == "@":
@ -74,12 +74,12 @@ def parse_line(line: str) -> IRCLine:
tag, _, value = tag.partition("=") tag, _, value = tag.partition("=")
tags[tag] = value tags[tag] = value
line, _, arbitrary = line.partition(" :") line, _, arbitrary_split = line.partition(" :")
arbitrary = arbitrary or None arbitrary = arbitrary_split or None
if line[0] == ":": if line[0] == ":":
prefix, line = line[1:].split(" ", 1) prefix_str, line = line[1:].split(" ", 1)
prefix = seperate_hostmask(prefix) prefix = seperate_hostmask(prefix_str)
command, _, line = line.partition(" ") command, _, line = line.partition(" ")
args = line.split(" ") args = line.split(" ")