diff --git a/src/utils/http.py b/src/utils/http.py index 3cae97ea..7da44c87 100644 --- a/src/utils/http.py +++ b/src/utils/http.py @@ -1,5 +1,5 @@ -import asyncio, codecs, ipaddress, re, signal, socket, traceback, typing -import urllib.error, urllib.parse, uuid +import asyncio, codecs, dataclasses, ipaddress, re, signal, socket, traceback +import typing, urllib.error, urllib.parse, uuid import json as _json import bs4, netifaces, requests, tornado.httpclient from src import IRCBot, utils @@ -53,46 +53,37 @@ class HTTPWrongContentTypeException(HTTPException): def throw_timeout(): raise HTTPTimeoutException() +@dataclasses.dataclass class Request(object): - def __init__(self, url: str, - get_params: typing.Dict[str, str]={}, post_data: typing.Any=None, - headers: typing.Dict[str, str]={}, + url: str + id: typing.Optional[str] = None + method: str = "GET" - json: bool=False, json_body: bool=False, allow_redirects: bool=True, - check_content_type: bool=True, parse: bool=False, - detect_encoding: bool=True, + get_params: typing.Dict[str, str] = dataclasses.field( + default_factory=dict) + post_data: typing.Any = None + headers: typing.Dict[str, str] = dataclasses.field( + default_factory=dict) + cookies: typing.Dict[str, str] = dataclasses.field( + default_factory=dict) - method: str="GET", parser: str="lxml", id: str=None, - fallback_encoding: str=None, content_type: str=None, - proxy: str=None, useragent: str=None, + json: bool = False + json_body: typing.Any = None - **kwargs): - self.id = id or str(uuid.uuid4()) + allow_redirects: bool = True + check_content_type: bool = True + parse: bool = False + detect_encoding: bool = True + parser: str = "lxml" + fallback_encoding: typing.Optional[str] = None + content_type: typing.Optional[str] = None + proxy: typing.Optional[str] = None + useragent: typing.Optional[str] = None - self.set_url(url) - self.method = method.upper() - self.get_params = get_params - self.post_data = post_data - self.headers = headers - - self.json = json - self.json_body = json_body - self.allow_redirects = allow_redirects - self.check_content_type = check_content_type - self.parse = parse - self.detect_encoding = detect_encoding - - self.parser = parser - self.fallback_encoding = fallback_encoding - self.content_type = content_type - self.proxy = proxy - self.useragent = useragent - - if kwargs: - if method == "POST": - self.post_data = kwargs - else: - self.get_params.update(kwargs) + def validate(self): + self.id = self.id or str(uuid.uuid4()) + self.set_url(self.url) + self.method = self.method.upper() def set_url(self, url: str): parts = urllib.parse.urlparse(url) @@ -166,7 +157,7 @@ def request(request_obj: typing.Union[str, Request], **kwargs) -> Response: return _request(request_obj) def _request(request_obj: Request) -> Response: - + request_obj.validate() def _wrap() -> Response: headers = request_obj.get_headers() response = requests.request( @@ -241,6 +232,7 @@ def request_many(requests: typing.List[Request]) -> typing.Dict[str, Response]: responses = {} async def _request(request): + request.validate() client = tornado.httpclient.AsyncHTTPClient() url = request.url if request.get_params: