diff --git a/modules/rest_api.py b/modules/rest_api.py index 1245bf83..9d86800f 100644 --- a/modules/rest_api.py +++ b/modules/rest_api.py @@ -43,6 +43,7 @@ class Response(object): def get_data(self): return self._data +_module = None _bot = None _events = None _log = None @@ -77,28 +78,9 @@ class Handler(http.server.BaseHTTPRequestHandler): def _minify_setting(self): return _bot.get_setting("rest-api-minify", False) - def url_for(self, headers, route, endpoint, get_params={}): - config_host = _bot.get_setting("rest-api-host", None) - - host = None - if not config_host == None: - host = config_host - elif "Host" in headers: - header_host, _, port = headers["Host"].partition(":") - if not port: - port = _bot.config.get("api-port", DEFAULT_PUBLIC_PORT) - host = "%s:%s" % (header_host, port) - - if host: - get_params_str = "" - if get_params: - get_params_str = "?%s" % urllib.parse.urlencode(get_params) - return "%s/%s/%s%s" % (host, route, endpoint, get_params_str) - else: - return None def _url_for(self, headers): return lambda route, endpoint, get_params={}: self.url_for( - headers, route, endpoint, get_params) + route, endpoint, get_params, headers.get("Host", None)) def _handle(self, method, path, endpoint, args): headers = utils.CaseInsensitiveDict(dict(self.headers.items())) @@ -174,6 +156,9 @@ class BitBotIPv6HTTPd(http.server.HTTPServer): utils.Setting("rest-api-host", "Public hostname:port for the REST API")) class Module(ModuleManager.BaseModule): def on_load(self): + global _module + _module = self + global _bot _bot = self.bot @@ -183,6 +168,8 @@ class Module(ModuleManager.BaseModule): global _log _log = self.log + self.exports.add("url-for", self._url_for) + self.httpd = None if self.bot.get_setting("rest-api", False): port = int(self.bot.config.get("api-port", str(DEFAULT_PORT))) @@ -211,3 +198,19 @@ class Module(ModuleManager.BaseModule): "permissions": event["args_split"][1:] }) event["stdout"].write("New API key ('%s'): %s" % (comment, api_key)) + + def _url_for(self, route, endpoint, get_params={}, host_override=None): + host = host_override or self.bot.get_setting("rest-api-host", None) + + host, _, port = host.partition(":") + if not port: + port = str(_bot.config.get("api-port", DEFAULT_PUBLIC_PORT)) + host = "%s:%s" % (host, port) + + if host: + get_params_str = "" + if get_params: + get_params_str = "?%s" % urllib.parse.urlencode(get_params) + return "%s/%s/%s%s" % (host, route, endpoint, get_params_str) + else: + return None