Add TimersContext, to be able to purge timers when modules are unloaded

This commit is contained in:
jesopo 2018-10-12 17:54:15 +01:00
parent 278ab7d76f
commit cb94fa9ae4
2 changed files with 41 additions and 8 deletions

View file

@ -22,16 +22,18 @@ class ModuleNotLoadedWarning(ModuleWarning):
pass pass
class BaseModule(object): class BaseModule(object):
def __init__(self, bot, events, exports): def __init__(self, bot, events, exports, timers):
self.bot = bot self.bot = bot
self.events = events self.events = events
self.exports = exports self.exports = exports
self.timers = timers
class ModuleManager(object): class ModuleManager(object):
def __init__(self, events, exports, config, log, directory): def __init__(self, events, exports, timers, config, log, directory):
self.events = events self.events = events
self.exports = exports self.exports = exports
self.config = config self.config = config
self.timers = timers
self.log = log self.log = log
self.directory = directory self.directory = directory
@ -84,7 +86,9 @@ class ModuleManager(object):
context = str(uuid.uuid4()) context = str(uuid.uuid4())
context_events = self.events.new_context(context) context_events = self.events.new_context(context)
context_exports = self.exports.new_context(context) context_exports = self.exports.new_context(context)
module_object = module.Module(bot, context_events, context_exports) context_timers = self.timers.new_context(context)
module_object = module.Module(bot, context_events, context_exports,
context_timers)
if not hasattr(module_object, "_name"): if not hasattr(module_object, "_name"):
module_object._name = name.title() module_object._name = name.title()
@ -144,6 +148,7 @@ class ModuleManager(object):
context = module._context context = module._context
self.events.purge_context(context) self.events.purge_context(context)
self.exports.purge_context(context) self.exports.purge_context(context)
self.timers.purge_context(context)
del sys.modules[self._import_name(name)] del sys.modules[self._import_name(name)]
references = sys.getrefcount(module) references = sys.getrefcount(module)

View file

@ -27,12 +27,27 @@ class Timer(object):
def done(self): def done(self):
return self._done return self._done
class TimersContext(object):
def __init__(self, parent, context):
self._parent = parent
self.context = context
def add(self, name, delay, next_due=None, **kwargs):
self._parent._add(self.context, name, delay, next_due, None, False,
kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs):
self._parent._add(self.context, name, delay, next_due, None, True,
kwargs)
class Timers(object): class Timers(object):
def __init__(self, database, events, log): def __init__(self, database, events, log):
self.database = database self.database = database
self.events = events self.events = events
self.log = log self.log = log
self.timers = [] self.timers = []
self.context_timers = {}
def new_context(self, context):
return TimersContext(self, context)
def setup(self, timers): def setup(self, timers):
for name, timer in timers: for name, timer in timers:
@ -49,14 +64,20 @@ class Timers(object):
self.database.bot_settings.delete("timer-%s" % timer.id) self.database.bot_settings.delete("timer-%s" % timer.id)
def add(self, name, delay, next_due=None, **kwargs): def add(self, name, delay, next_due=None, **kwargs):
self._add(name, delay, next_due, None, False, kwargs) self._add(None, name, delay, next_due, None, False, kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs): def add_persistent(self, name, delay, next_due=None, **kwargs):
self._add(name, delay, next_due, None, True, kwargs) self._add(None, name, delay, next_due, None, True, kwargs)
def _add(self, name, delay, next_due, id, persist, kwargs): def _add(self, context, name, delay, next_due, id, persist, kwargs):
id = id or uuid.uuid4().hex id = id or uuid.uuid4().hex
timer = Timer(id, name, delay, next_due, kwargs) timer = Timer(id, name, delay, next_due, kwargs)
if persist: if persist:
self._persist(timer) self._persist(timer)
if context and not persist:
if not context in self.context_timers:
self.context_timers[context] = []
self.context_timers[context].append(timer)
else:
self.timers.append(timer) self.timers.append(timer)
def next(self): def next(self):
@ -65,11 +86,18 @@ class Timers(object):
return None return None
return max(min(times), 0) return max(min(times), 0)
def get_timers(self):
return self.timers + sum(self.context_timers.values(), [])
def call(self): def call(self):
for timer in self.timers[:]: for timer in self.get_timers():
if timer.due(): if timer.due():
timer.finish() timer.finish()
self.events.on("timer.%s" % timer.name).call(timer=timer, self.events.on("timer.%s" % timer.name).call(timer=timer,
**timer.kwargs) **timer.kwargs)
if timer.done(): if timer.done():
self._remove(timer) self._remove(timer)
def purge_context(self, context):
if context in self.context_timers:
del self.context_timers[context]