add ModuleManager.try_reload_modules(), to try reloading in a transaction
if any of the modules fails to reload, rollback and use the already loaded modules. closes #179
This commit is contained in:
parent
4940aff877
commit
a9111c7241
4 changed files with 72 additions and 67 deletions
|
@ -71,26 +71,11 @@ class Module(ModuleManager.BaseModule):
|
||||||
:help: Reload all modules
|
:help: Reload all modules
|
||||||
:permission: reload-all-modules
|
:permission: reload-all-modules
|
||||||
"""
|
"""
|
||||||
success = []
|
result = self.bot.try_reload_modules()
|
||||||
fail = []
|
if result.success:
|
||||||
for name in list(self.bot.modules.modules.keys()):
|
event["stdout"].write(result.message)
|
||||||
try:
|
|
||||||
self.bot.modules.unload_module(name)
|
|
||||||
except ModuleManager.ModuleWarning:
|
|
||||||
continue
|
|
||||||
except:
|
|
||||||
fail.append(name)
|
|
||||||
load_success, load_fail = self.bot.load_modules(safe=True)
|
|
||||||
success.extend(load_success)
|
|
||||||
fail.extend(load_fail)
|
|
||||||
|
|
||||||
if success and fail:
|
|
||||||
event["stdout"].write("Reloaded %d modules, %d failed" % (
|
|
||||||
len(success), len(fail)))
|
|
||||||
elif fail:
|
|
||||||
event["stdout"].write("Failed to reload all modules")
|
|
||||||
else:
|
else:
|
||||||
event["stdout"].write("Reloaded %d modules" % len(success))
|
event["stderr"].write(result.message)
|
||||||
|
|
||||||
@utils.hook("received.command.enablemodule", min_args=1)
|
@utils.hook("received.command.enablemodule", min_args=1)
|
||||||
def enable(self, event):
|
def enable(self, event):
|
||||||
|
|
|
@ -58,18 +58,8 @@ class Module(ModuleManager.BaseModule):
|
||||||
def _reload_modules(self):
|
def _reload_modules(self):
|
||||||
self.bot.log.info("Reloading modules")
|
self.bot.log.info("Reloading modules")
|
||||||
|
|
||||||
success = []
|
result = self.bot.try_reload_modules()
|
||||||
fail = []
|
if result.success:
|
||||||
for name in list(self.bot.modules.modules.keys()):
|
self.bot.log.info(result.message)
|
||||||
try:
|
else:
|
||||||
self.bot.modules.unload_module(name)
|
self.bot.log.warn(result.message)
|
||||||
except ModuleManager.ModuleWarning:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
failed.append(name)
|
|
||||||
continue
|
|
||||||
load_success, load_fail = self.bot.load_modules(safe=True)
|
|
||||||
fail.extend(load_fail)
|
|
||||||
|
|
||||||
self.bot.log.info("Reloaded %d modules (%d failed)",
|
|
||||||
[len(load_success), len(fail)])
|
|
||||||
|
|
|
@ -149,22 +149,27 @@ class Bot(object):
|
||||||
if throw:
|
if throw:
|
||||||
raise BitBotPanic()
|
raise BitBotPanic()
|
||||||
|
|
||||||
|
def _module_lists(self):
|
||||||
|
db_whitelist = set(self.get_setting("module-whitelist", []))
|
||||||
|
db_blacklist = set(self.get_setting("module-blacklist", []))
|
||||||
|
|
||||||
|
conf_whitelist = self.config.get("module-whitelist", "").split(",")
|
||||||
|
conf_blacklist = self.config.get("module-blacklist", "").split(",")
|
||||||
|
|
||||||
|
conf_whitelist = set(filter(None, conf_whitelist))
|
||||||
|
conf_blacklist = set(filter(None, conf_blacklist))
|
||||||
|
|
||||||
|
return (db_whitelist|conf_whitelist, db_blacklist|conf_blacklist)
|
||||||
|
|
||||||
def load_modules(self, safe: bool=False
|
def load_modules(self, safe: bool=False
|
||||||
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
||||||
db_blacklist = set(self.get_setting("module-blacklist", []))
|
whitelist, blacklist = self._module_lists()
|
||||||
db_whitelist = set(self.get_setting("module-whitelist", []))
|
|
||||||
|
|
||||||
conf_blacklist = self.config.get("module-blacklist", "").split(",")
|
|
||||||
conf_whitelist = self.config.get("module-whitelist", "").split(",")
|
|
||||||
|
|
||||||
conf_blacklist = set(filter(None, conf_blacklist))
|
|
||||||
conf_whitelist = set(filter(None, conf_whitelist))
|
|
||||||
|
|
||||||
blacklist = db_blacklist|conf_blacklist
|
|
||||||
whitelist = db_whitelist|conf_whitelist
|
|
||||||
|
|
||||||
return self.modules.load_modules(self, whitelist=whitelist,
|
return self.modules.load_modules(self, whitelist=whitelist,
|
||||||
blacklist=blacklist, safe=safe)
|
blacklist=blacklist, safe=safe)
|
||||||
|
def try_reload_modules(self) -> ModuleManager.TryReloadResult:
|
||||||
|
whitelist, blacklist = self._module_lists()
|
||||||
|
return self.modules.try_reload_modules(self, whitelist=whitelist,
|
||||||
|
blacklist=blacklist)
|
||||||
|
|
||||||
def add_server(self, server_id: int, connect: bool = True,
|
def add_server(self, server_id: int, connect: bool = True,
|
||||||
connection_param_args: typing.Dict[str, str]={}
|
connection_param_args: typing.Dict[str, str]={}
|
||||||
|
|
|
@ -36,6 +36,11 @@ class ModuleType(enum.Enum):
|
||||||
FILE = 0
|
FILE = 0
|
||||||
DIRECTORY = 1
|
DIRECTORY = 1
|
||||||
|
|
||||||
|
class TryReloadResult(object):
|
||||||
|
def __init__(self, success: bool, message: str):
|
||||||
|
self.success = success
|
||||||
|
self.message = message
|
||||||
|
|
||||||
class BaseModule(object):
|
class BaseModule(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
bot: "IRCBot.Bot",
|
bot: "IRCBot.Bot",
|
||||||
|
@ -155,8 +160,8 @@ class ModuleManager(object):
|
||||||
for directory in self.directories:
|
for directory in self.directories:
|
||||||
paths.append(os.path.join(directory, name))
|
paths.append(os.path.join(directory, name))
|
||||||
return paths
|
return paths
|
||||||
def _import_name(self, name: str) -> str:
|
def _import_name(self, name: str, context: str) -> str:
|
||||||
return "bitbot_%s" % name
|
return "%s_%s" % (name, context)
|
||||||
|
|
||||||
def from_context(self, context: str) -> typing.Optional[LoadedModule]:
|
def from_context(self, context: str) -> typing.Optional[LoadedModule]:
|
||||||
for module in self.modules.values():
|
for module in self.modules.values():
|
||||||
|
@ -198,7 +203,9 @@ class ModuleManager(object):
|
||||||
|
|
||||||
self._check_hashflags(bot, definition)
|
self._check_hashflags(bot, definition)
|
||||||
|
|
||||||
import_name = self._import_name(definition.name)
|
context = str(uuid.uuid4())
|
||||||
|
import_name = self._import_name(definition.name, context)
|
||||||
|
|
||||||
import_spec = importlib.util.spec_from_file_location(import_name,
|
import_spec = importlib.util.spec_from_file_location(import_name,
|
||||||
definition.filename)
|
definition.filename)
|
||||||
module = importlib.util.module_from_spec(import_spec)
|
module = importlib.util.module_from_spec(import_spec)
|
||||||
|
@ -214,7 +221,6 @@ class ModuleManager(object):
|
||||||
raise ModuleLoadException("module '%s' has a 'Module' attribute "
|
raise ModuleLoadException("module '%s' has a 'Module' attribute "
|
||||||
"but it is not a class." % definition.name)
|
"but it is not a class." % definition.name)
|
||||||
|
|
||||||
context = str(uuid.uuid4())
|
|
||||||
context_events = self.events.new_context(context)
|
context_events = self.events.new_context(context)
|
||||||
context_exports = self.exports.new_context(context)
|
context_exports = self.exports.new_context(context)
|
||||||
context_timers = self.timers.new_context(context)
|
context_timers = self.timers.new_context(context)
|
||||||
|
@ -316,39 +322,27 @@ class ModuleManager(object):
|
||||||
def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
|
def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
|
||||||
blacklist: typing.List[str]=[], safe: bool=False
|
blacklist: typing.List[str]=[], safe: bool=False
|
||||||
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
||||||
fail = []
|
|
||||||
success = []
|
|
||||||
|
|
||||||
loadable, nonloadable = self._list_valid_modules(bot, whitelist, blacklist)
|
loadable, nonloadable = self._list_valid_modules(bot, whitelist, blacklist)
|
||||||
|
|
||||||
for definition in nonloadable:
|
for definition in nonloadable:
|
||||||
self.log.warn("Not loading module '%s'", [definition.name])
|
self.log.warn("Not loading module '%s'", [definition.name])
|
||||||
|
|
||||||
for definition in loadable:
|
for definition in loadable:
|
||||||
try:
|
self.load_module(bot, definition)
|
||||||
self.load_module(bot, definition)
|
|
||||||
except ModuleWarning:
|
|
||||||
fail.append(definition.name)
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
if safe:
|
|
||||||
fail.append(definition.name)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
success.append(definition.name)
|
|
||||||
return success, fail
|
|
||||||
|
|
||||||
def unload_module(self, name: str):
|
def unload_module(self, name: str):
|
||||||
if not name in self.modules:
|
if not name in self.modules:
|
||||||
raise ModuleNotLoadedException(name)
|
raise ModuleNotLoadedException(name)
|
||||||
loaded_module = self.modules[name]
|
loaded_module = self.modules[name]
|
||||||
|
self._unload_module(loaded_module)
|
||||||
|
del self.modules[loaded_module.name]
|
||||||
|
|
||||||
|
def _unload_module(self, loaded_module: LoadedModule):
|
||||||
if hasattr(loaded_module.module, "unload"):
|
if hasattr(loaded_module.module, "unload"):
|
||||||
try:
|
try:
|
||||||
loaded_module.module.unload()
|
loaded_module.module.unload()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
del self.modules[loaded_module.name]
|
|
||||||
|
|
||||||
context = loaded_module.context
|
context = loaded_module.context
|
||||||
self.events.purge_context(context)
|
self.events.purge_context(context)
|
||||||
|
@ -378,6 +372,37 @@ class ModuleManager(object):
|
||||||
[loaded_module.name,
|
[loaded_module.name,
|
||||||
", ".join([str(referrer) for referrer in referrers])])
|
", ".join([str(referrer) for referrer in referrers])])
|
||||||
|
|
||||||
|
def try_reload_modules(self, bot: "IRCBot.Bot",
|
||||||
|
whitelist: typing.List[str], blacklist: typing.List[str]):
|
||||||
|
loadable, nonloadable = self._list_valid_modules(
|
||||||
|
bot, whitelist, blacklist)
|
||||||
|
|
||||||
|
old_modules = self.modules
|
||||||
|
self.modules = {}
|
||||||
|
|
||||||
|
failed = None
|
||||||
|
for definition in loadable:
|
||||||
|
try:
|
||||||
|
self.load_module(bot, definition)
|
||||||
|
except Exception as e:
|
||||||
|
failed = (definition, e)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not failed == None:
|
||||||
|
for module in self.modules.values():
|
||||||
|
self._unload_module(module)
|
||||||
|
self.modules = old_modules
|
||||||
|
|
||||||
|
definition, exception = failed
|
||||||
|
return TryReloadResult(False,
|
||||||
|
"Failed to load %s (%s), rolling back reload" %
|
||||||
|
(definition.name, str(exception)))
|
||||||
|
else:
|
||||||
|
for module in old_modules.values():
|
||||||
|
self._unload_module(module)
|
||||||
|
return TryReloadResult(True, "Reloaded %d modules" %
|
||||||
|
len(self.modules.keys()))
|
||||||
|
|
||||||
def _list_valid_modules(self, bot: "IRCBot.Bot",
|
def _list_valid_modules(self, bot: "IRCBot.Bot",
|
||||||
whitelist: typing.List[str], blacklist: typing.List[str]):
|
whitelist: typing.List[str], blacklist: typing.List[str]):
|
||||||
module_definitions = self.list_modules()
|
module_definitions = self.list_modules()
|
||||||
|
|
Loading…
Reference in a new issue