diff --git a/modules/rest_api.py b/modules/rest_api.py index c0a83f6e..e081cf7c 100644 --- a/modules/rest_api.py +++ b/modules/rest_api.py @@ -7,6 +7,41 @@ import http.server, json, socket, ssl, threading, uuid, urllib.parse from src import ModuleManager, utils +class Response(object): + def __init__(self, compact=False): + self._compact = compact + self._headers = {} + self._data = b"" + self.code = 200 + self.content_type = "text/plain" + def write(self, data): + self._data += data + def write_text(self, data): + self._data += data.encode("utf8") + def write_json(self, obj): + if self._compact: + data = json.dumps(event_response, sort_keys=True, + separators=(",", ":")) + else: + data = json.dumps(event_response, sort_keys=True, indent=4) + self._data += data.encode("utf8") + + def set_header(self, key: str, value: str): + self._headers[key] = value + def get_headers(self): + headers = {} + has_content_type = False + for key, value in self._headers.items(): + if key.lower() == "content-type": + has_content_type = True + headers[key] = value + if not has_content_type: + headers["Content-Type"] = self.content_type + return headers + + def get_data(self): + return self._data + _bot = None _events = None _log = None @@ -29,12 +64,12 @@ class Handler(http.server.BaseHTTPRequestHandler): content_length = int(self.headers.get("content-length", 0)) return self.rfile.read(content_length) - def _respond(self, code, headers, data): - self.send_response(code) - for key, value in headers.items(): + def _respond(self, response): + self.send_response(response.code) + for key, value in response.get_headers().items(): self.send_header(key, value) self.end_headers() - self.wfile.write(data.encode("utf8")) + self.wfile.write(response.get_data()) def _get_settings(self, key): key_setting = _bot.get_setting("api-key-%s" % key, {}) @@ -51,9 +86,8 @@ class Handler(http.server.BaseHTTPRequestHandler): params = self._url_params() data = self._body() - response = "" - code = 404 - content_type = "text/plain" + response = Response() + response.code = 404 hooks = _events.on("api").on(method).on(endpoint).get_hooks() if hooks: @@ -74,29 +108,20 @@ class Handler(http.server.BaseHTTPRequestHandler): event_response = _bot.trigger(lambda: _events.on("api").on(method).on( endpoint).call_for_result_unsafe(params=params, - path=args, data=data, headers=headers)) + path=args, data=data, headers=headers, + response=response)) except Exception as e: _log.error("failed to call API endpoint \"%s\"", [path], exc_info=True) - code = 500 + response.code = 500 if not event_response == None: - content_type = "application/json" - if minify: - response = json.dumps(event_response, - sort_keys=True, separators=(",", ":")) - else: - response = json.dumps(event_response, - sort_keys=True, indent=4) - code = 200 + response.write_json(event_response) + response.content_type = "application/json" else: - code = 401 + response.code = 401 - headers = { - "Content-type": content_type - } - - self._respond(code, headers, response) + self._respond(response) _log.debug("[HTTP] finishing _handle for %s from %s:%d (%d)", [method, self.client_address[0], self.client_address[1], code])