Fix/refactor issues brought up by type hint linting

This commit is contained in:
jesopo 2018-10-30 17:49:35 +00:00
parent d0e3574227
commit b543e31cd2
13 changed files with 73 additions and 59 deletions

View file

@ -8,14 +8,14 @@ class Cache(object):
def cache(self, item: typing.Any) -> str:
return self._cache(item, None)
def temporary_cache(self, item: typing.Any, timeout: float)-> str:
return self._cache(item, timeout)
def _cache(self, item: typing.Any, timeout: float) -> str:
return self._cache(item, time.monotonic()+timeout)
def _cache(self, item: typing.Any, timeout: typing.Optional[float]) -> str:
id = str(uuid.uuid4())
self._items[id] = [item, time.monotonic()+timeout]
self._items[id] = [item, timeout]
self._item_to_id[item] = id
return id
def next_expiration(self) -> float:
def next_expiration(self) -> typing.Optional[float]:
expirations = [self._items[id][1] for id in self._items]
expirations = list(filter(None, expirations))
if not expirations:

View file

@ -32,7 +32,7 @@ class EventCallback(object):
self.function = function
self.priority = priority
self.kwargs = kwargs
self.docstring = utils.parse.docstring(function.__doc__)
self.docstring = utils.parse.docstring(function.__doc__ or "")
def call(self, event: Event) -> typing.Any:
return self.function(event)
@ -49,6 +49,7 @@ class EventHook(object):
self.parent = parent
self._children = {}
self._hooks = []
self._replayed = False
self._stored_events = []
self._context_hooks = {}
@ -57,10 +58,14 @@ class EventHook(object):
def _get_path(self) -> str:
path = []
parent = self
while not parent == None and not parent.name == None:
path.append(parent.name)
parent = parent.parent
parent = self # type: typing.Optional[EventHook]
while not parent == None:
cast_parent = typing.cast(EventHook, parent)
if cast_parent.name == None:
break
path.append(typing.cast(str, cast_parent.name))
parent = cast_parent.parent
return DEFAULT_EVENT_DELIMITER.join(path[::-1])
def new_context(self, context: str) -> "EventHookContext":
@ -72,8 +77,8 @@ class EventHook(object):
def _context_hook(self, context: str, function: CALLBACK_TYPE,
priority: int, replay: bool, kwargs: dict) -> EventCallback:
return self._hook(function, context, priority, replay, kwargs)
def _hook(self, function: CALLBACK_TYPE, context: str, priority: int,
replay: bool, kwargs: dict) -> EventCallback:
def _hook(self, function: CALLBACK_TYPE, context: typing.Optional[str],
priority: int, replay: bool, kwargs: dict) -> EventCallback:
callback = EventCallback(function, priority, kwargs)
if context == None:
@ -83,10 +88,10 @@ class EventHook(object):
self._context_hooks[context] = []
self._context_hooks[context].append(callback)
if replay and not self._stored_events == None:
if replay and not self._replayed:
for kwargs in self._stored_events:
self._call(kwargs, True, None)
self._stored_events = None
self._replayed = True
return callback
def unhook(self, callback: "EventHook"):
@ -102,58 +107,65 @@ class EventHook(object):
for context in empty:
del self._context_hooks[context]
def _make_multiple_hook(self, source: "EventHook", context: str,
events: typing.List[str]) -> "MultipleEventHook":
def _make_multiple_hook(self, source: "EventHook",
context: typing.Optional[str],
events: typing.Iterable[str]) -> "MultipleEventHook":
multiple_event_hook = MultipleEventHook()
for event in events:
event_hook = source.get_child(event)
if not context == None:
event_hook = event_hook.new_context(context)
multiple_event_hook._add(event_hook)
context_hook = event_hook.new_context(typing.cast(str, context))
multiple_event_hook._add(typing.cast(EventHook, context_hook))
else:
multiple_event_hook._add(event_hook)
return multiple_event_hook
def on(self, subevent: str, *extra_subevents,
delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook":
def on(self, subevent: str, *extra_subevents: str,
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, None, delimiter)
def _context_on(self, context: str, subevent: str,
extra_subevents: typing.List[str],
extra_subevents: typing.Tuple[str, ...],
delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, context, delimiter)
def _on(self, subevent: str, extra_subevents: typing.List[str],
context: str, delimiter: str) -> "EventHook":
def _on(self, subevent: str, extra_subevents: typing.Tuple[str, ...],
context: typing.Optional[str], delimiter: str) -> "EventHook":
if delimiter in subevent:
event_chain = subevent.split(delimiter)
event_obj = self
for event_name in event_chain:
if DEFAULT_MULTI_DELIMITER in event_name:
return self._make_multiple_hook(event_obj, context,
multiple_hook = self._make_multiple_hook(event_obj, context,
event_name.split(DEFAULT_MULTI_DELIMITER))
return typing.cast(EventHook, multiple_hook)
event_obj = event_obj.get_child(event_name)
if not context == None:
return event_obj.new_context(context)
context_hook = event_obj.new_context(typing.cast(str, context))
return typing.cast(EventHook, context_hook)
return event_obj
if extra_subevents:
return self._make_multiple_hook(self, context,
multiple_hook = self._make_multiple_hook(self, context,
(subevent,)+extra_subevents)
return typing.cast(EventHook, multiple_hook)
child = self.get_child(subevent)
if not context == None:
child = child.new_context(context)
context_child = child.new_context(typing.cast(str, context))
child = typing.cast(EventHook, context_child)
return child
def call_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_limited(1, **kwargs) or [default])[0]
def assure_call(self, **kwargs):
if not self._stored_events == None:
if not self._replayed:
self._stored_events.append(kwargs)
else:
self._call(kwargs, True, None)
def call(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
return self._call(kwargs, True, maximum)
def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
@ -163,7 +175,7 @@ class EventHook(object):
) -> typing.List[typing.Any]:
return self._call(kwargs, False, maximum)
def _call(self, kwargs: dict, safe: bool, maximum: int
def _call(self, kwargs: dict, safe: bool, maximum: typing.Optional[int]
) -> typing.List[typing.Any]:
event_path = self._get_path()
self.log.trace("calling event: \"%s\" (params: %s)",
@ -223,10 +235,10 @@ class EventHook(object):
def get_hooks(self) -> typing.List[EventCallback]:
return sorted(self._hooks + sum(self._context_hooks.values(), []),
key=lambda e: e.priority)
def get_children(self) -> typing.List["EventHook"]:
def get_children(self) -> typing.List[str]:
return list(self._children.keys())
def is_empty(self) -> bool:
return len(self.get_hooks() + self.get_children()) == 0
return (len(self.get_hooks())+len(self.get_children())) == 0
class MultipleEventHook(object):
def __init__(self):

View file

@ -12,7 +12,8 @@ class Exports(object):
self._add(None, setting, value)
def _context_add(self, context: str, setting: str, value: typing.Any):
self._add(context, setting, value)
def _add(self, context: str, setting: str, value: typing.Any):
def _add(self, context: typing.Optional[str], setting: str,
value: typing.Any):
if context == None:
if not setting in self._exports:
self._exports[setting] = []

View file

@ -36,7 +36,7 @@ class Bot(object):
self.lock.release()
def add_server(self, server_id: int, connect: bool = True
) -> typing.Optional[IRCServer.Server]:
) -> IRCServer.Server:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id)
@ -44,7 +44,7 @@ class Bot(object):
hostname, port, password, ipv4, tls, bindhost, nickname, username,
realname)
if not new_server.get_setting("connect", True):
return
return new_server
self._events.on("new.server").call(server=new_server)
if connect and new_server.get_setting("connect", True):
self.connect(new_server)

View file

@ -1,5 +1,5 @@
import re, typing
from src import IRCBot, utils
from src import IRCBot, IRCServer, utils
class BufferLine(object):
def __init__(self, sender: str, message: str, action: bool, tags: dict,

View file

@ -9,9 +9,9 @@ class Channel(IRCObject.Object):
self.server = server
self.bot = bot
self.topic = ""
self.topic_setter_nickname = None
self.topic_setter_username = None
self.topic_setter_hostname = None
self.topic_setter_nickname = None # type: typing.Optional[str]
self.topic_setter_username = None # type: typing.Optional[str]
self.topic_setter_hostname = None # type: typing.Optional[str]
self.topic_time = 0
self.users = set([])
self.modes = {}

View file

@ -160,7 +160,7 @@ class Server(IRCObject.Object):
self.nickname = nickname
self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
def is_own_nickname(self, nickname: str):
return utils.irc.equals(self, nickname, self.nickname)
return utils.irc.equals(self.case_mapping, nickname, self.nickname)
def add_own_mode(self, mode: str, arg: str=None):
self.own_modes[mode] = arg

View file

@ -29,7 +29,7 @@ class User(IRCObject.Object):
def get_id(self)-> int:
return (self.identified_account_id_override or
self.identified_account_id or self._id)
def get_identified_account(self) -> str:
def get_identified_account(self) -> typing.Optional[str]:
return (self.identified_account_override or self.identified_account)
def set_nickname(self, nickname: str):

View file

@ -33,7 +33,7 @@ class Socket(object):
data = self._read_buffer+data
self._read_buffer = b""
if not self.delimiter == None:
data_split = data.split(delimiter)
data_split = data.split(self.delimiter)
if data_split[-1]:
self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split]

View file

@ -2,8 +2,8 @@ import time, typing, uuid
from src import Database, EventManager, Logging
class Timer(object):
def __init__(self, id: int, context: str, name: str, delay: float,
next_due: float, kwargs: dict):
def __init__(self, id: str, context: typing.Optional[str], name: str,
delay: float, next_due: typing.Optional[float], kwargs: dict):
self.id = id
self.context = context
self.name = name
@ -46,7 +46,7 @@ class Timers(object):
def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
for name, timer in timers:
id = name.split("timer-", 1)[1]
self._add(timer["name"], None, timer["delay"], timer[
self._add(None, timer["name"], timer["delay"], timer[
"next-due"], id, False, timer["kwargs"])
def _persist(self, timer: Timer):
@ -67,9 +67,10 @@ class Timers(object):
def add_persistent(self, name: str, delay: float, next_due: float=None,
**kwargs):
self._add(None, name, delay, next_due, None, True, kwargs)
def _add(self, context: str, name: str, delay: float, next_due: float,
id: str, persist: bool, kwargs: dict):
id = id or uuid.uuid4().hex
def _add(self, context: typing.Optional[str], name: str, delay: float,
next_due: typing.Optional[float], id: typing.Optional[str],
persist: bool, kwargs: dict):
id = id or str(uuid.uuid4())
timer = Timer(id, context, name, delay, next_due, kwargs)
if persist:
self._persist(timer)
@ -81,7 +82,7 @@ class Timers(object):
else:
self.timers.append(timer)
def next(self) -> float:
def next(self) -> typing.Optional[float]:
times = filter(None, [timer.time_left() for timer in self.get_timers()])
if not times:
return None

View file

@ -93,9 +93,9 @@ def parse_number(s: str) -> str:
pass
unit = s[-1].lower()
number = s[:-1]
number_str = s[:-1]
try:
number = decimal.Decimal(number)
number = decimal.Decimal(number_str)
except:
raise ValueError("Invalid format '%s' passed to parse_number" % number)
@ -155,7 +155,7 @@ def export(setting: str, value: typing.Any):
return _export_func
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
def top_10(items: typing.List[typing.Any],
def top_10(items: typing.Dict[typing.Any, typing.Any],
convert_key: TOP_10_CALLABLE=lambda x: x,
value_format: TOP_10_CALLABLE=lambda x: x):
top_10 = sorted(items.keys())

View file

@ -8,7 +8,7 @@ REGEX_HTTP = re.compile("https?://", re.I)
RESPONSE_MAX = (1024*1024)*100
class HTTPException:
class HTTPException(Exception):
pass
class HTTPTimeoutException(HTTPException):
pass
@ -52,7 +52,7 @@ def get_url(url: str, method: str="GET", get_params: dict={},
if soup:
soup = bs4.BeautifulSoup(response_content, parser)
if code:
return response.code, soup
return response.status_code, soup
return soup
data = response_content.decode(response.encoding or fallback_encoding)

View file

@ -53,7 +53,7 @@ def seperate_hostmask(hostmask: str) -> IRCHostmask:
return IRCHostmask(nickname, username, hostname, hostmask)
class IRCLine(object):
def __init__(self, tags: dict, prefix: str, command: str,
def __init__(self, tags: dict, prefix: typing.Optional[str], command: str,
args: typing.List[str], arbitrary: typing.Optional[str],
last: str):
self.tags = tags
@ -65,7 +65,7 @@ class IRCLine(object):
def parse_line(line: str) -> IRCLine:
tags = {}
prefix = None
prefix = typing.Optional[IRCHostmask]
command = None
if line[0] == "@":
@ -74,12 +74,12 @@ def parse_line(line: str) -> IRCLine:
tag, _, value = tag.partition("=")
tags[tag] = value
line, _, arbitrary = line.partition(" :")
arbitrary = arbitrary or None
line, _, arbitrary_split = line.partition(" :")
arbitrary = arbitrary_split or None
if line[0] == ":":
prefix, line = line[1:].split(" ", 1)
prefix = seperate_hostmask(prefix)
prefix_str, line = line[1:].split(" ", 1)
prefix = seperate_hostmask(prefix_str)
command, _, line = line.partition(" ")
args = line.split(" ")