Refactor ModuleManager
This commit is contained in:
parent
d0634bb54e
commit
552902d462
1 changed files with 83 additions and 46 deletions
|
@ -18,6 +18,11 @@ class ModuleUnloadException(ModuleException):
|
||||||
class ModuleNotLoadedWarning(ModuleWarning):
|
class ModuleNotLoadedWarning(ModuleWarning):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class ModuleDependencyNotFulfilled(ModuleException):
|
||||||
|
def __init__(self, message, dependency):
|
||||||
|
ModuleException.__init__(self, message)
|
||||||
|
self.dependency = dependency
|
||||||
|
|
||||||
class ModuleType(enum.Enum):
|
class ModuleType(enum.Enum):
|
||||||
FILE = 0
|
FILE = 0
|
||||||
DIRECTORY = 1
|
DIRECTORY = 1
|
||||||
|
@ -43,6 +48,23 @@ class BaseModule(object):
|
||||||
def command_line(self, args: str):
|
def command_line(self, args: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class ModuleDefinition(object):
|
||||||
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
filename: str,
|
||||||
|
type: ModuleType,
|
||||||
|
hashflags: typing.List[typing.Tuple[str, str]]):
|
||||||
|
self.name = name
|
||||||
|
self.filename = filename
|
||||||
|
self.type = type
|
||||||
|
self.hashflags = hashflags
|
||||||
|
def get_dependencies(self):
|
||||||
|
dependencies = []
|
||||||
|
for key, value in self.hashflags:
|
||||||
|
if key == "depends-on":
|
||||||
|
dependencies.append(value)
|
||||||
|
return sorted(dependencies)
|
||||||
|
|
||||||
class LoadedModule(object):
|
class LoadedModule(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -70,19 +92,39 @@ class ModuleManager(object):
|
||||||
self.directory = directory
|
self.directory = directory
|
||||||
|
|
||||||
self.modules = {} # type: typing.Dict[str, LoadedModule]
|
self.modules = {} # type: typing.Dict[str, LoadedModule]
|
||||||
self.waiting_requirement = {} # type: typing.Dict[str, typing.Set[str]]
|
|
||||||
|
|
||||||
def list_modules(self) -> typing.List[typing.Tuple[ModuleType, str]]:
|
def list_modules(self) -> typing.List[ModuleDefinition]:
|
||||||
modules = []
|
modules = []
|
||||||
|
|
||||||
for file_module in glob.glob(os.path.join(self.directory, "*.py")):
|
for file_module in glob.glob(os.path.join(self.directory, "*.py")):
|
||||||
modules.append((ModuleType.FILE, file_module))
|
modules.append(self.define_module(ModuleType.FILE, file_module))
|
||||||
|
|
||||||
for directory_module in glob.glob(os.path.join(
|
for directory_module in glob.glob(os.path.join(
|
||||||
self.directory, "*", "__init__.py")):
|
self.directory, "*", "__init__.py")):
|
||||||
directory = os.path.dirname(directory_module)
|
modules.append(self.define_module(ModuleType.DIRECTORY,
|
||||||
modules.append((ModuleType.DIRECTORY, directory))
|
directory_module))
|
||||||
return sorted(modules, key=lambda module: module[1])
|
return sorted(modules, key=lambda module: module.name)
|
||||||
|
|
||||||
|
def define_module(self, type: ModuleType, filename: str
|
||||||
|
) -> ModuleDefinition:
|
||||||
|
if type == ModuleType.DIRECTORY:
|
||||||
|
name = os.path.dirname(filename)
|
||||||
|
else:
|
||||||
|
name = filename
|
||||||
|
name = self._module_name(name)
|
||||||
|
|
||||||
|
return ModuleDefinition(name, filename, type,
|
||||||
|
utils.parse.hashflags(filename))
|
||||||
|
|
||||||
|
def find_module(self, name: str) -> ModuleDefinition:
|
||||||
|
type = ModuleType.FILE
|
||||||
|
path = self._module_path(name)
|
||||||
|
|
||||||
|
if os.path.isdir(path):
|
||||||
|
type = ModuleType.DIRECTORY
|
||||||
|
path = os.path.join(path, "__init__.py")
|
||||||
|
|
||||||
|
return self.define_module(type, path)
|
||||||
|
|
||||||
def _module_name(self, path: str) -> str:
|
def _module_name(self, path: str) -> str:
|
||||||
return os.path.basename(path).rsplit(".py", 1)[0].lower()
|
return os.path.basename(path).rsplit(".py", 1)[0].lower()
|
||||||
|
@ -107,15 +149,14 @@ class ModuleManager(object):
|
||||||
) -> typing.Any:
|
) -> typing.Any:
|
||||||
return getattr(obj, magic) if hasattr(obj, magic) else default
|
return getattr(obj, magic) if hasattr(obj, magic) else default
|
||||||
|
|
||||||
def _load_module(self, bot: "IRCBot.Bot", name: str) -> LoadedModule:
|
def _load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition
|
||||||
path = self._module_path(name)
|
) -> LoadedModule:
|
||||||
if os.path.isdir(path) and os.path.isfile(os.path.join(
|
dependencies = definition.get_dependencies()
|
||||||
path, "__init__.py")):
|
for dependency in dependencies:
|
||||||
path = os.path.join(path, "__init__.py")
|
if not dependency in self.modules:
|
||||||
else:
|
raise ModuleDependencyNotFulfilled(dependency)
|
||||||
path = "%s.py" % path
|
|
||||||
|
|
||||||
for hashflag, value in utils.parse.hashflags(path):
|
for hashflag, value in definition.hashflags:
|
||||||
if hashflag == "ignore":
|
if hashflag == "ignore":
|
||||||
# nope, ignore this module.
|
# nope, ignore this module.
|
||||||
raise ModuleNotLoadedWarning("module ignored")
|
raise ModuleNotLoadedWarning("module ignored")
|
||||||
|
@ -125,16 +166,9 @@ class ModuleManager(object):
|
||||||
# nope, required config option not present.
|
# nope, required config option not present.
|
||||||
raise ModuleNotLoadedWarning("required config not present")
|
raise ModuleNotLoadedWarning("required config not present")
|
||||||
|
|
||||||
elif hashflag == "require-module" and value:
|
import_name = self._import_name(definition.name)
|
||||||
requirement = value.lower()
|
import_spec = importlib.util.spec_from_file_location(import_name,
|
||||||
if not requirement in self.modules:
|
definition.filename)
|
||||||
if not requirement in self.waiting_requirement:
|
|
||||||
self.waiting_requirement[requirement] = set([])
|
|
||||||
self.waiting_requirement[requirement].add(path)
|
|
||||||
raise ModuleNotLoadedWarning("waiting for requirement")
|
|
||||||
|
|
||||||
import_name = self._import_name(name)
|
|
||||||
import_spec = importlib.util.spec_from_file_location(import_name, path)
|
|
||||||
module = importlib.util.module_from_spec(import_spec)
|
module = importlib.util.module_from_spec(import_spec)
|
||||||
sys.modules[import_name] = module
|
sys.modules[import_name] = module
|
||||||
loader = typing.cast(importlib.abc.Loader, import_spec.loader)
|
loader = typing.cast(importlib.abc.Loader, import_spec.loader)
|
||||||
|
@ -143,10 +177,10 @@ class ModuleManager(object):
|
||||||
module_object_pointer = getattr(module, "Module", None)
|
module_object_pointer = getattr(module, "Module", None)
|
||||||
if not module_object_pointer:
|
if not module_object_pointer:
|
||||||
raise ModuleLoadException("module '%s' doesn't have a "
|
raise ModuleLoadException("module '%s' doesn't have a "
|
||||||
"'Module' class." % name)
|
"'Module' class." % definition.name)
|
||||||
if not inspect.isclass(module_object_pointer):
|
if not inspect.isclass(module_object_pointer):
|
||||||
raise ModuleLoadException("module '%s' has a 'Module' attribute "
|
raise ModuleLoadException("module '%s' has a 'Module' attribute "
|
||||||
"but it is not a class." % name)
|
"but it is not a class." % definition.name)
|
||||||
|
|
||||||
context = str(uuid.uuid4())
|
context = str(uuid.uuid4())
|
||||||
context_events = self.events.new_context(context)
|
context_events = self.events.new_context(context)
|
||||||
|
@ -156,7 +190,7 @@ class ModuleManager(object):
|
||||||
context_exports, context_timers, self.log)
|
context_exports, context_timers, self.log)
|
||||||
|
|
||||||
if not hasattr(module_object, "_name"):
|
if not hasattr(module_object, "_name"):
|
||||||
module_object._name = name.title()
|
module_object._name = definition.name.title()
|
||||||
for attribute_name in dir(module_object):
|
for attribute_name in dir(module_object):
|
||||||
attribute = getattr(module_object, attribute_name)
|
attribute = getattr(module_object, attribute_name)
|
||||||
for hook in self._get_magic(attribute,
|
for hook in self._get_magic(attribute,
|
||||||
|
@ -167,28 +201,26 @@ class ModuleManager(object):
|
||||||
utils.consts.BITBOT_EXPORTS_MAGIC, []):
|
utils.consts.BITBOT_EXPORTS_MAGIC, []):
|
||||||
context_exports.add(export["setting"], export["value"])
|
context_exports.add(export["setting"], export["value"])
|
||||||
|
|
||||||
if name in self.modules:
|
if definition.name in self.modules:
|
||||||
raise ModuleNameCollisionException("Module name '%s' "
|
raise ModuleNameCollisionException("Module name '%s' "
|
||||||
"attempted to be used twice")
|
"attempted to be used twice" % definition.name)
|
||||||
|
|
||||||
return LoadedModule(name, module_object, context, import_name)
|
return LoadedModule(definition.name, module_object, context,
|
||||||
|
import_name)
|
||||||
|
|
||||||
def load_module(self, bot: "IRCBot.Bot", name: str) -> LoadedModule:
|
def load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition
|
||||||
|
) -> LoadedModule:
|
||||||
try:
|
try:
|
||||||
loaded_module = self._load_module(bot, name)
|
loaded_module = self._load_module(bot, definition)
|
||||||
except ModuleWarning as warning:
|
except ModuleWarning as warning:
|
||||||
self.log.warn("Module '%s' not loaded", [name])
|
self.log.warn("Module '%s' not loaded", [definition.name])
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error("Failed to load module \"%s\": %s",
|
self.log.error("Failed to load module \"%s\": %s",
|
||||||
[name, str(e)])
|
[definition.name, str(e)])
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self.modules[loaded_module.name] = loaded_module
|
self.modules[loaded_module.name] = loaded_module
|
||||||
if loaded_module.name in self.waiting_requirement:
|
|
||||||
for requirement_name in self.waiting_requirement[
|
|
||||||
loaded_module.name]:
|
|
||||||
self.load_module(bot, requirement_name)
|
|
||||||
self.log.debug("Module '%s' loaded", [loaded_module.name])
|
self.log.debug("Module '%s' loaded", [loaded_module.name])
|
||||||
return loaded_module
|
return loaded_module
|
||||||
|
|
||||||
|
@ -197,21 +229,26 @@ class ModuleManager(object):
|
||||||
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
) -> typing.Tuple[typing.List[str], typing.List[str]]:
|
||||||
fail = []
|
fail = []
|
||||||
success = []
|
success = []
|
||||||
for type, path in self.list_modules():
|
|
||||||
name = self._module_name(path)
|
module_definitions = self.list_modules()
|
||||||
if name in whitelist or (not whitelist and not name in blacklist):
|
|
||||||
|
#TODO figure out dependency tree
|
||||||
|
|
||||||
|
for definition in module_definitions:
|
||||||
|
if definition.name in whitelist or (
|
||||||
|
not whitelist and not definition.name in blacklist):
|
||||||
try:
|
try:
|
||||||
self.load_module(bot, name)
|
self.load_module(bot, definition)
|
||||||
except ModuleWarning:
|
except ModuleWarning:
|
||||||
fail.append(name)
|
fail.append(definition.name)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if safe:
|
if safe:
|
||||||
fail.append(name)
|
fail.append(definition.name)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
success.append(name)
|
success.append(definition.name)
|
||||||
return success, fail
|
return success, fail
|
||||||
|
|
||||||
def unload_module(self, name: str):
|
def unload_module(self, name: str):
|
||||||
|
|
Loading…
Reference in a new issue