Support IRCv3's tls/STARTTLS

This commit is contained in:
jesopo 2018-09-07 16:34:51 +01:00
parent 208a323d48
commit a4f0d1bf28
3 changed files with 36 additions and 11 deletions

View file

@ -65,11 +65,7 @@ class Server(object):
self.socket.settimeout(5.0) self.socket.settimeout(5.0)
if self.tls: if self.tls:
context = ssl.SSLContext(OUR_TLS_PROTOCOL) self.tls_wrap()
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
self.socket = context.wrap_socket(self.socket)
self.cached_fileno = self.socket.fileno() self.cached_fileno = self.socket.fileno()
self.events.on("timer").on("rejoin").hook(self.try_rejoin) self.events.on("timer").on("rejoin").hook(self.try_rejoin)
@ -82,6 +78,13 @@ class Server(object):
fileno = self.socket.fileno() fileno = self.socket.fileno()
return self.cached_fileno if fileno == -1 else fileno return self.cached_fileno if fileno == -1 else fileno
def tls_wrap(self):
context = ssl.SSLContext(OUR_TLS_PROTOCOL)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
self.socket = context.wrap_socket(self.socket)
def connect(self): def connect(self):
self.socket.connect((self.target_hostname, self.port)) self.socket.connect((self.target_hostname, self.port))
self.send_capibility_ls() self.send_capibility_ls()
@ -297,6 +300,8 @@ class Server(object):
self.send("CAP END") self.send("CAP END")
def send_authenticate(self, text): def send_authenticate(self, text):
self.send("AUTHENTICATE %s" % text) self.send("AUTHENTICATE %s" % text)
def send_starttls(self):
self.send("STARTTLS")
def waiting_for_capabilities(self): def waiting_for_capabilities(self):
return bool(len(self._capabilities_waiting)) return bool(len(self._capabilities_waiting))

View file

@ -3,17 +3,11 @@ import base64
class Module(object): class Module(object):
def __init__(self, bot, events, exports): def __init__(self, bot, events, exports):
self.bot = bot self.bot = bot
events.on("preprocess.connect").hook(self.preprocess_connect)
events.on("received.cap.ls").hook(self.on_cap) events.on("received.cap.ls").hook(self.on_cap)
events.on("received.cap.ack").hook(self.on_cap_ack) events.on("received.cap.ack").hook(self.on_cap_ack)
events.on("received.authenticate").hook(self.on_authenticate) events.on("received.authenticate").hook(self.on_authenticate)
events.on("received.numeric.903").hook(self.sasl_success) events.on("received.numeric.903").hook(self.sasl_success)
def preprocess_connect(self, event):
sasl = event["server"].get_setting("sasl")
if sasl:
event["server"].send_capability_request("sasl")
def on_cap(self, event): def on_cap(self, event):
has_sasl = "sasl" in event["capabilities"] has_sasl = "sasl" in event["capabilities"]
has_mechanisms = has_sasl and not event["capabilities"]["sasl" has_mechanisms = has_sasl and not event["capabilities"]["sasl"

26
modules/starttls.py Normal file
View file

@ -0,0 +1,26 @@
import base64
class Module(object):
def __init__(self, bot, events, exports):
self.bot = bot
events.on("received.cap.ls").hook(self.on_cap)
events.on("received.cap.ack").hook(self.on_cap_ack)
events.on("received.numeric.670").hook(self.starttls_success)
events.on("received.numeric.691").hook(self.starttls_failed)
def on_cap(self, event):
if "tls" in event["capabilities"].keys() and not event["server"].tls:
event["server"].queue_capability("tls")
def on_cap_ack(self, event):
if "tls" in event["capabilities"].keys():
event["server"].send_starttls()
event["server"].wait_for_capability("tls")
def starttls_success(self, event):
event["server"].wrap_tls()
event["server"].capability_done("tls")
def starttls_failed(self, event):
event["server"].capability_done("tls")