refuse to unload core modules

This commit is contained in:
jesopo 2020-01-28 15:27:13 +00:00
parent a35ed3c6ea
commit 0744f8b63a

View file

@ -18,6 +18,8 @@ class ModuleLoadException(ModuleException):
pass pass
class ModuleUnloadException(ModuleException): class ModuleUnloadException(ModuleException):
pass pass
class ModuleCannotUnloadException(ModuleException):
pass
class ModuleNotLoadedWarning(ModuleWarning): class ModuleNotLoadedWarning(ModuleWarning):
pass pass
@ -70,16 +72,14 @@ class BaseModule(object):
os.makedirs(path) os.makedirs(path)
return os.path.join(path, filename) return os.path.join(path, filename)
@dataclasses.dataclass
class ModuleDefinition(object): class ModuleDefinition(object):
def __init__(self, name: str
name: str, filename: str
filename: str, type: ModuleType
type: ModuleType, hashflags: typing.List[typing.Tuple[str, typing.Optional[str]]]
hashflags: typing.List[typing.Tuple[str, typing.Optional[str]]]): is_core: bool
self.name = name
self.filename = filename
self.type = type
self.hashflags = hashflags
def get_dependencies(self): def get_dependencies(self):
dependencies = [] dependencies = []
for key, value in self.hashflags: for key, value in self.hashflags:
@ -87,18 +87,14 @@ class ModuleDefinition(object):
dependencies.append(value) dependencies.append(value)
return sorted(dependencies) return sorted(dependencies)
@dataclasses.dataclass
class LoadedModule(object): class LoadedModule(object):
def __init__(self, name: str
name: str, title: str
title: str, module: BaseModule
module: BaseModule, context: str
context: str, import_name: str
import_name: str): is_core: bool
self.name = name
self.title = title
self.module = module
self.context = context
self.import_name = import_name
class ModuleManager(object): class ModuleManager(object):
def __init__(self, def __init__(self,
@ -119,25 +115,27 @@ class ModuleManager(object):
self.modules = {} # type: typing.Dict[str, LoadedModule] self.modules = {} # type: typing.Dict[str, LoadedModule]
def _list_modules(self, directory: str def _list_modules(self, directory: str, is_core: bool
) -> typing.Dict[str, ModuleDefinition]: ) -> typing.Dict[str, ModuleDefinition]:
modules = [] modules = []
for file_module in glob.glob(os.path.join(directory, "*.py")): for file_module in glob.glob(os.path.join(directory, "*.py")):
modules.append(self.define_module(ModuleType.FILE, file_module)) modules.append(
self.define_module(ModuleType.FILE, file_module, is_core))
for directory_module in glob.glob(os.path.join( for directory_module in glob.glob(os.path.join(
directory, "*", "__init__.py")): directory, "*", "__init__.py")):
modules.append(self.define_module(ModuleType.DIRECTORY, modules.append(self.define_module(ModuleType.DIRECTORY,
directory_module)) directory_module, is_core))
return {definition.name: definition for definition in modules} return {definition.name: definition for definition in modules}
def list_modules(self, whitelist: typing.List[str], def list_modules(self, whitelist: typing.List[str],
blacklist: typing.List[str]) -> typing.Dict[str, ModuleDefinition]: blacklist: typing.List[str]) -> typing.Dict[str, ModuleDefinition]:
core_modules = self._list_modules(self._core_modules) core_modules = self._list_modules(self._core_modules, True)
extra_modules: typing.Dict[str, ModuleDefinition] = {} extra_modules: typing.Dict[str, ModuleDefinition] = {}
for directory in self._extra_modules: for directory in self._extra_modules:
for name, module in self._list_modules(directory).items(): for name, module in self._list_modules(directory, False).items():
if (not name in extra_modules and if (not name in extra_modules and
(name in whitelist or (name in whitelist or
(not whitelist and not name in blacklist))): (not whitelist and not name in blacklist))):
@ -148,7 +146,7 @@ class ModuleManager(object):
modules.update(core_modules) modules.update(core_modules)
return modules return modules
def define_module(self, type: ModuleType, filename: str def define_module(self, type: ModuleType, filename: str, is_core: bool,
) -> ModuleDefinition: ) -> ModuleDefinition:
if type == ModuleType.DIRECTORY: if type == ModuleType.DIRECTORY:
name = os.path.dirname(filename) name = os.path.dirname(filename)
@ -157,14 +155,13 @@ class ModuleManager(object):
name = self._module_name(name) name = self._module_name(name)
return ModuleDefinition(name, filename, type, return ModuleDefinition(name, filename, type,
utils.parse.hashflags(filename)) utils.parse.hashflags(filename), is_core)
def find_module(self, name: str) -> ModuleDefinition: def find_module(self, name: str) -> ModuleDefinition:
type = ModuleType.FILE type = ModuleType.FILE
paths = self._module_paths(name) paths = self._module_paths(name)
path = None for is_core, possible_path in paths:
for possible_path in paths:
if os.path.isdir(possible_path): if os.path.isdir(possible_path):
type = ModuleType.DIRECTORY type = ModuleType.DIRECTORY
possible_path = os.path.join(possible_path, "__init__.py") possible_path = os.path.join(possible_path, "__init__.py")
@ -172,20 +169,17 @@ class ModuleManager(object):
possible_path = "%s.py" % possible_path possible_path = "%s.py" % possible_path
if os.path.isfile(possible_path): if os.path.isfile(possible_path):
path = possible_path return self.define_module(type, possible_path, is_core)
break
if not path:
raise ModuleNotFoundException(name) raise ModuleNotFoundException(name)
return self.define_module(type, path)
def _module_name(self, path: str) -> str: def _module_name(self, path: str) -> str:
return os.path.basename(path).rsplit(".py", 1)[0].lower() return os.path.basename(path).rsplit(".py", 1)[0].lower()
def _module_paths(self, name: str) -> typing.List[str]: def _module_paths(self, name: str) -> typing.List[typing.Tuple[bool, str]]:
paths = [] paths = []
for directory in [self._core_modules]+self._extra_modules:
paths.append(os.path.join(directory, name)) for i, directory in enumerate([self._core_modules]+self._extra_modules):
paths.append((i==0, os.path.join(directory, name)))
return paths return paths
def _import_name(self, name: str, context: str) -> str: def _import_name(self, name: str, context: str) -> str:
return "%s_%s" % (name, context) return "%s_%s" % (name, context)
@ -275,7 +269,7 @@ class ModuleManager(object):
context_exports.add(key, value) context_exports.add(key, value)
return LoadedModule(definition.name, module_title, module_object, return LoadedModule(definition.name, module_title, module_object,
context, import_name) context, import_name, definition.is_core)
def load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition def load_module(self, bot: "IRCBot.Bot", definition: ModuleDefinition
) -> LoadedModule: ) -> LoadedModule:
@ -364,6 +358,9 @@ class ModuleManager(object):
del self.modules[loaded_module.name] del self.modules[loaded_module.name]
def _unload_module(self, loaded_module: LoadedModule): def _unload_module(self, loaded_module: LoadedModule):
if loaded_module.is_core:
raise ModuleCannotUnloadException("cannot unload core modules")
if hasattr(loaded_module.module, "unload"): if hasattr(loaded_module.module, "unload"):
try: try:
loaded_module.module.unload() loaded_module.module.unload()