diff --git a/modules/modules.py b/modules/modules.py index 25a2c289..cac145ba 100644 --- a/modules/modules.py +++ b/modules/modules.py @@ -71,26 +71,11 @@ class Module(ModuleManager.BaseModule): :help: Reload all modules :permission: reload-all-modules """ - success = [] - fail = [] - for name in list(self.bot.modules.modules.keys()): - try: - self.bot.modules.unload_module(name) - except ModuleManager.ModuleWarning: - continue - except: - fail.append(name) - load_success, load_fail = self.bot.load_modules(safe=True) - success.extend(load_success) - fail.extend(load_fail) - - if success and fail: - event["stdout"].write("Reloaded %d modules, %d failed" % ( - len(success), len(fail))) - elif fail: - event["stdout"].write("Failed to reload all modules") + result = self.bot.try_reload_modules() + if result.success: + event["stdout"].write(result.message) else: - event["stdout"].write("Reloaded %d modules" % len(success)) + event["stderr"].write(result.message) @utils.hook("received.command.enablemodule", min_args=1) def enable(self, event): diff --git a/modules/signals.py b/modules/signals.py index 286852d6..f3adcbb0 100644 --- a/modules/signals.py +++ b/modules/signals.py @@ -58,18 +58,8 @@ class Module(ModuleManager.BaseModule): def _reload_modules(self): self.bot.log.info("Reloading modules") - success = [] - fail = [] - for name in list(self.bot.modules.modules.keys()): - try: - self.bot.modules.unload_module(name) - except ModuleManager.ModuleWarning: - continue - except Exception as e: - failed.append(name) - continue - load_success, load_fail = self.bot.load_modules(safe=True) - fail.extend(load_fail) - - self.bot.log.info("Reloaded %d modules (%d failed)", - [len(load_success), len(fail)]) + result = self.bot.try_reload_modules() + if result.success: + self.bot.log.info(result.message) + else: + self.bot.log.warn(result.message) diff --git a/src/IRCBot.py b/src/IRCBot.py index 8eb1873e..fbc74ad9 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -149,22 +149,27 @@ class Bot(object): if throw: raise BitBotPanic() + def _module_lists(self): + db_whitelist = set(self.get_setting("module-whitelist", [])) + db_blacklist = set(self.get_setting("module-blacklist", [])) + + conf_whitelist = self.config.get("module-whitelist", "").split(",") + conf_blacklist = self.config.get("module-blacklist", "").split(",") + + conf_whitelist = set(filter(None, conf_whitelist)) + conf_blacklist = set(filter(None, conf_blacklist)) + + return (db_whitelist|conf_whitelist, db_blacklist|conf_blacklist) + def load_modules(self, safe: bool=False ) -> typing.Tuple[typing.List[str], typing.List[str]]: - db_blacklist = set(self.get_setting("module-blacklist", [])) - db_whitelist = set(self.get_setting("module-whitelist", [])) - - conf_blacklist = self.config.get("module-blacklist", "").split(",") - conf_whitelist = self.config.get("module-whitelist", "").split(",") - - conf_blacklist = set(filter(None, conf_blacklist)) - conf_whitelist = set(filter(None, conf_whitelist)) - - blacklist = db_blacklist|conf_blacklist - whitelist = db_whitelist|conf_whitelist - + whitelist, blacklist = self._module_lists() return self.modules.load_modules(self, whitelist=whitelist, blacklist=blacklist, safe=safe) + def try_reload_modules(self) -> ModuleManager.TryReloadResult: + whitelist, blacklist = self._module_lists() + return self.modules.try_reload_modules(self, whitelist=whitelist, + blacklist=blacklist) def add_server(self, server_id: int, connect: bool = True, connection_param_args: typing.Dict[str, str]={} diff --git a/src/ModuleManager.py b/src/ModuleManager.py index bbfcc618..74ef7d01 100644 --- a/src/ModuleManager.py +++ b/src/ModuleManager.py @@ -36,6 +36,11 @@ class ModuleType(enum.Enum): FILE = 0 DIRECTORY = 1 +class TryReloadResult(object): + def __init__(self, success: bool, message: str): + self.success = success + self.message = message + class BaseModule(object): def __init__(self, bot: "IRCBot.Bot", @@ -155,8 +160,8 @@ class ModuleManager(object): for directory in self.directories: paths.append(os.path.join(directory, name)) return paths - def _import_name(self, name: str) -> str: - return "bitbot_%s" % name + def _import_name(self, name: str, context: str) -> str: + return "%s_%s" % (name, context) def from_context(self, context: str) -> typing.Optional[LoadedModule]: for module in self.modules.values(): @@ -198,7 +203,9 @@ class ModuleManager(object): self._check_hashflags(bot, definition) - import_name = self._import_name(definition.name) + context = str(uuid.uuid4()) + import_name = self._import_name(definition.name, context) + import_spec = importlib.util.spec_from_file_location(import_name, definition.filename) module = importlib.util.module_from_spec(import_spec) @@ -214,7 +221,6 @@ class ModuleManager(object): raise ModuleLoadException("module '%s' has a 'Module' attribute " "but it is not a class." % definition.name) - context = str(uuid.uuid4()) context_events = self.events.new_context(context) context_exports = self.exports.new_context(context) context_timers = self.timers.new_context(context) @@ -316,39 +322,27 @@ class ModuleManager(object): def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[], blacklist: typing.List[str]=[], safe: bool=False ) -> typing.Tuple[typing.List[str], typing.List[str]]: - fail = [] - success = [] - loadable, nonloadable = self._list_valid_modules(bot, whitelist, blacklist) for definition in nonloadable: self.log.warn("Not loading module '%s'", [definition.name]) for definition in loadable: - try: - self.load_module(bot, definition) - except ModuleWarning: - fail.append(definition.name) - continue - except Exception as e: - if safe: - fail.append(definition.name) - continue - else: - raise - success.append(definition.name) - return success, fail + self.load_module(bot, definition) def unload_module(self, name: str): if not name in self.modules: raise ModuleNotLoadedException(name) loaded_module = self.modules[name] + self._unload_module(loaded_module) + del self.modules[loaded_module.name] + + def _unload_module(self, loaded_module: LoadedModule): if hasattr(loaded_module.module, "unload"): try: loaded_module.module.unload() except: pass - del self.modules[loaded_module.name] context = loaded_module.context self.events.purge_context(context) @@ -378,6 +372,37 @@ class ModuleManager(object): [loaded_module.name, ", ".join([str(referrer) for referrer in referrers])]) + def try_reload_modules(self, bot: "IRCBot.Bot", + whitelist: typing.List[str], blacklist: typing.List[str]): + loadable, nonloadable = self._list_valid_modules( + bot, whitelist, blacklist) + + old_modules = self.modules + self.modules = {} + + failed = None + for definition in loadable: + try: + self.load_module(bot, definition) + except Exception as e: + failed = (definition, e) + break + + if not failed == None: + for module in self.modules.values(): + self._unload_module(module) + self.modules = old_modules + + definition, exception = failed + return TryReloadResult(False, + "Failed to load %s (%s), rolling back reload" % + (definition.name, str(exception))) + else: + for module in old_modules.values(): + self._unload_module(module) + return TryReloadResult(True, "Reloaded %d modules" % + len(self.modules.keys())) + def _list_valid_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str], blacklist: typing.List[str]): module_definitions = self.list_modules()