Refactor ModuleManager

This commit is contained in:
jesopo 2019-05-25 18:24:50 +01:00
parent d0634bb54e
commit 552902d462

View file

@ -18,6 +18,11 @@ class ModuleUnloadException(ModuleException):
class ModuleNotLoadedWarning(ModuleWarning): class ModuleNotLoadedWarning(ModuleWarning):
pass pass
class ModuleDependencyNotFulfilled(ModuleException):
def __init__(self, message, dependency):
ModuleException.__init__(self, message)
self.dependency = dependency
class ModuleType(enum.Enum): class ModuleType(enum.Enum):
FILE = 0 FILE = 0
DIRECTORY = 1 DIRECTORY = 1
@ -43,6 +48,23 @@ class BaseModule(object):
def command_line(self, args: str): def command_line(self, args: str):
pass pass
class ModuleDefinition(object):
def __init__(self,
name: str,
filename: str,
type: ModuleType,
hashflags: typing.List[typing.Tuple[str, str]]):
self.name = name
self.filename = filename
self.type = type
self.hashflags = hashflags
def get_dependencies(self):
dependencies = []
for key, value in self.hashflags:
if key == "depends-on":
dependencies.append(value)
return sorted(dependencies)
class LoadedModule(object): class LoadedModule(object):
def __init__(self, def __init__(self,
name: str, name: str,
@ -70,19 +92,39 @@ class ModuleManager(object):
self.directory = directory self.directory = directory
self.modules = {} # type: typing.Dict[str, LoadedModule] self.modules = {} # type: typing.Dict[str, LoadedModule]
self.waiting_requirement = {} # type: typing.Dict[str, typing.Set[str]]
def list_modules(self) -> typing.List[typing.Tuple[ModuleType, str]]: def list_modules(self) -> typing.List[ModuleDefinition]:
modules = [] modules = []
for file_module in glob.glob(os.path.join(self.directory, "*.py")): for file_module in glob.glob(os.path.join(self.directory, "*.py")):
modules.append((ModuleType.FILE, file_module)) modules.append(self.define_module(ModuleType.FILE, file_module))
for directory_module in glob.glob(os.path.join( for directory_module in glob.glob(os.path.join(
self.directory, "*", "__init__.py")): self.directory, "*", "__init__.py")):
directory = os.path.dirname(directory_module) modules.append(self.define_module(ModuleType.DIRECTORY,
modules.append((ModuleType.DIRECTORY, directory)) directory_module))
return sorted(modules, key=lambda module: module[1]) return sorted(modules, key=lambda module: module.name)
def define_module(self, type: ModuleType, filename: str
) -> ModuleDefinition:
if type == ModuleType.DIRECTORY:
name = os.path.dirname(filename)
else:
name = filename
name = self._module_name(name)
return ModuleDefinition(name, filename, type,
utils.parse.hashflags(filename))
def find_module(self, name: str) -> ModuleDefinition:
type = ModuleType.FILE
path = self._module_path(name)
if os.path.isdir(path):
type = ModuleType.DIRECTORY
path = os.path.join(path, "__init__.py")
return self.define_module(type, path)
def _module_name(self, path: str) -> str: def _module_name(self, path: str) -> str:
return os.path.basename(path).rsplit(".py", 1)[0].lower() return os.path.basename(path).rsplit(".py", 1)[0].lower()
@ -107,15 +149,14 @@ class ModuleManager(object):
) -> typing.Any: ) -> typing.Any:
return getattr(obj, magic) if hasattr(obj, magic) else default return getattr(obj, magic) if hasattr(obj, magic) else default
def _load_module(self, bot: "IRCBot.Bot", name: str) -> LoadedModule: def _load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition
path = self._module_path(name) ) -> LoadedModule:
if os.path.isdir(path) and os.path.isfile(os.path.join( dependencies = definition.get_dependencies()
path, "__init__.py")): for dependency in dependencies:
path = os.path.join(path, "__init__.py") if not dependency in self.modules:
else: raise ModuleDependencyNotFulfilled(dependency)
path = "%s.py" % path
for hashflag, value in utils.parse.hashflags(path): for hashflag, value in definition.hashflags:
if hashflag == "ignore": if hashflag == "ignore":
# nope, ignore this module. # nope, ignore this module.
raise ModuleNotLoadedWarning("module ignored") raise ModuleNotLoadedWarning("module ignored")
@ -125,16 +166,9 @@ class ModuleManager(object):
# nope, required config option not present. # nope, required config option not present.
raise ModuleNotLoadedWarning("required config not present") raise ModuleNotLoadedWarning("required config not present")
elif hashflag == "require-module" and value: import_name = self._import_name(definition.name)
requirement = value.lower() import_spec = importlib.util.spec_from_file_location(import_name,
if not requirement in self.modules: definition.filename)
if not requirement in self.waiting_requirement:
self.waiting_requirement[requirement] = set([])
self.waiting_requirement[requirement].add(path)
raise ModuleNotLoadedWarning("waiting for requirement")
import_name = self._import_name(name)
import_spec = importlib.util.spec_from_file_location(import_name, path)
module = importlib.util.module_from_spec(import_spec) module = importlib.util.module_from_spec(import_spec)
sys.modules[import_name] = module sys.modules[import_name] = module
loader = typing.cast(importlib.abc.Loader, import_spec.loader) loader = typing.cast(importlib.abc.Loader, import_spec.loader)
@ -143,10 +177,10 @@ class ModuleManager(object):
module_object_pointer = getattr(module, "Module", None) module_object_pointer = getattr(module, "Module", None)
if not module_object_pointer: if not module_object_pointer:
raise ModuleLoadException("module '%s' doesn't have a " raise ModuleLoadException("module '%s' doesn't have a "
"'Module' class." % name) "'Module' class." % definition.name)
if not inspect.isclass(module_object_pointer): if not inspect.isclass(module_object_pointer):
raise ModuleLoadException("module '%s' has a 'Module' attribute " raise ModuleLoadException("module '%s' has a 'Module' attribute "
"but it is not a class." % name) "but it is not a class." % definition.name)
context = str(uuid.uuid4()) context = str(uuid.uuid4())
context_events = self.events.new_context(context) context_events = self.events.new_context(context)
@ -156,7 +190,7 @@ class ModuleManager(object):
context_exports, context_timers, self.log) context_exports, context_timers, self.log)
if not hasattr(module_object, "_name"): if not hasattr(module_object, "_name"):
module_object._name = name.title() module_object._name = definition.name.title()
for attribute_name in dir(module_object): for attribute_name in dir(module_object):
attribute = getattr(module_object, attribute_name) attribute = getattr(module_object, attribute_name)
for hook in self._get_magic(attribute, for hook in self._get_magic(attribute,
@ -167,28 +201,26 @@ class ModuleManager(object):
utils.consts.BITBOT_EXPORTS_MAGIC, []): utils.consts.BITBOT_EXPORTS_MAGIC, []):
context_exports.add(export["setting"], export["value"]) context_exports.add(export["setting"], export["value"])
if name in self.modules: if definition.name in self.modules:
raise ModuleNameCollisionException("Module name '%s' " raise ModuleNameCollisionException("Module name '%s' "
"attempted to be used twice") "attempted to be used twice" % definition.name)
return LoadedModule(name, module_object, context, import_name) return LoadedModule(definition.name, module_object, context,
import_name)
def load_module(self, bot: "IRCBot.Bot", name: str) -> LoadedModule: def load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition
) -> LoadedModule:
try: try:
loaded_module = self._load_module(bot, name) loaded_module = self._load_module(bot, definition)
except ModuleWarning as warning: except ModuleWarning as warning:
self.log.warn("Module '%s' not loaded", [name]) self.log.warn("Module '%s' not loaded", [definition.name])
raise raise
except Exception as e: except Exception as e:
self.log.error("Failed to load module \"%s\": %s", self.log.error("Failed to load module \"%s\": %s",
[name, str(e)]) [definition.name, str(e)])
raise raise
self.modules[loaded_module.name] = loaded_module self.modules[loaded_module.name] = loaded_module
if loaded_module.name in self.waiting_requirement:
for requirement_name in self.waiting_requirement[
loaded_module.name]:
self.load_module(bot, requirement_name)
self.log.debug("Module '%s' loaded", [loaded_module.name]) self.log.debug("Module '%s' loaded", [loaded_module.name])
return loaded_module return loaded_module
@ -197,21 +229,26 @@ class ModuleManager(object):
) -> typing.Tuple[typing.List[str], typing.List[str]]: ) -> typing.Tuple[typing.List[str], typing.List[str]]:
fail = [] fail = []
success = [] success = []
for type, path in self.list_modules():
name = self._module_name(path) module_definitions = self.list_modules()
if name in whitelist or (not whitelist and not name in blacklist):
#TODO figure out dependency tree
for definition in module_definitions:
if definition.name in whitelist or (
not whitelist and not definition.name in blacklist):
try: try:
self.load_module(bot, name) self.load_module(bot, definition)
except ModuleWarning: except ModuleWarning:
fail.append(name) fail.append(definition.name)
continue continue
except Exception as e: except Exception as e:
if safe: if safe:
fail.append(name) fail.append(definition.name)
continue continue
else: else:
raise raise
success.append(name) success.append(definition.name)
return success, fail return success, fail
def unload_module(self, name: str): def unload_module(self, name: str):