experimental support for rss bindhost

This commit is contained in:
jesopo 2020-04-09 15:47:20 +01:00
parent 7fdb9a1e55
commit b19e956f68
3 changed files with 22 additions and 3 deletions

View file

@ -7,10 +7,15 @@ import feedparser
RSS_INTERVAL = 60 # 1 minute RSS_INTERVAL = 60 # 1 minute
SETTING_BIND = utils.Setting("rss-bindhost",
"Which local address to bind to for RSS requests", example="127.0.0.1")
@utils.export("botset", utils.IntSetting("rss-interval", @utils.export("botset", utils.IntSetting("rss-interval",
"Interval (in seconds) between RSS polls", example="120")) "Interval (in seconds) between RSS polls", example="120"))
@utils.export("channelset", utils.BoolSetting("rss-shorten", @utils.export("channelset", utils.BoolSetting("rss-shorten",
"Whether or not to shorten RSS urls")) "Whether or not to shorten RSS urls"))
@utils.export("serverset", SETTING_BIND)
@utils.export("channelset", SETTING_BIND)
class Module(ModuleManager.BaseModule): class Module(ModuleManager.BaseModule):
_name = "RSS" _name = "RSS"
def on_load(self): def on_load(self):
@ -57,8 +62,10 @@ class Module(ModuleManager.BaseModule):
return return
requests = [] requests = []
for url in hooks.keys(): for url, (server, channel) in hooks.items():
requests.append(utils.http.Request(url, id=url)) bindhost = channel.get_setting("rss-bindhost",
server.get_setting("rss-bindhost", None))
requests.append(utils.http.Request(url, id=url, bindhost=bindhost))
pages = utils.http.request_many(requests) pages = utils.http.request_many(requests)

View file

@ -15,3 +15,4 @@ scrypt ==0.8.13
suds-jurko ==0.6 suds-jurko ==0.6
tornado ==6.0.3 tornado ==6.0.3
tweepy ==3.8.0 tweepy ==3.8.0
requests-toolbelt ==0.9.1

View file

@ -3,6 +3,7 @@ import typing, urllib.error, urllib.parse, uuid
import json as _json import json as _json
import bs4, netifaces, requests, tornado.httpclient import bs4, netifaces, requests, tornado.httpclient
from src import IRCBot, utils from src import IRCBot, utils
from requests_toolbelt.adapters import source
REGEX_URL = re.compile("https?://\S+", re.I) REGEX_URL = re.compile("https?://\S+", re.I)
@ -78,6 +79,8 @@ class Request(object):
timeout: int=5 timeout: int=5
bindhost: typing.Optional[str] = None
def validate(self): def validate(self):
self.id = self.id or str(uuid.uuid4()) self.id = self.id or str(uuid.uuid4())
self.set_url(self.url) self.set_url(self.url)
@ -189,11 +192,17 @@ def _request(request_obj: Request) -> Response:
redirect = 0 redirect = 0
current_url = request_obj.url current_url = request_obj.url
session = requests.Session()
if not request_obj.bindhost is None:
new_source = source.SourceAddressAdapter(request_obj.bindhost)
session.mount('http://', new_source)
session.mount('https://', new_source)
while True: while True:
if request_obj.check_hostname: if request_obj.check_hostname:
_assert_allowed(current_url) _assert_allowed(current_url)
response = requests.request( response = session.request(
request_obj.method, request_obj.method,
current_url, current_url,
headers=headers, headers=headers,
@ -218,6 +227,8 @@ def _request(request_obj: Request) -> Response:
raise ValueError("Response too large") raise ValueError("Response too large")
break break
session.close()
headers = utils.CaseInsensitiveDict(dict(response.headers)) headers = utils.CaseInsensitiveDict(dict(response.headers))
our_response = Response(response.status_code, response_content, our_response = Response(response.status_code, response_content,
encoding=response.encoding, headers=headers, encoding=response.encoding, headers=headers,