Refactor IRCServer .connect() logic

This commit is contained in:
jesopo 2018-11-05 11:53:33 +00:00
parent 5963580cb8
commit c2ebc7b5e4
2 changed files with 15 additions and 22 deletions

View file

@ -43,11 +43,13 @@ class Bot(object):
new_server = IRCServer.Server(self, self._events, server_id, alias, new_server = IRCServer.Server(self, self._events, server_id, alias,
hostname, port, password, ipv4, tls, bindhost, nickname, username, hostname, port, password, ipv4, tls, bindhost, nickname, username,
realname) realname)
if not new_server.get_setting("connect", True):
return new_server
self._events.on("new.server").call(server=new_server) self._events.on("new.server").call(server=new_server)
if connect and new_server.get_setting("connect", True):
self.connect(new_server) if not connect or not new_server.get_setting("connect", True):
return new_server
self.connect(new_server)
return new_server return new_server
def add_socket(self, sock: socket.socket): def add_socket(self, sock: socket.socket):

View file

@ -58,22 +58,6 @@ class Server(IRCObject.Object):
self.attempted_join = {} # type: typing.Dict[str, typing.Optional[str]] self.attempted_join = {} # type: typing.Dict[str, typing.Optional[str]]
self.ping_sent = False self.ping_sent = False
if ipv4:
self.socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
else:
self.socket = socket.socket(socket.AF_INET6,
socket.SOCK_STREAM)
if bindhost:
self.socket.bind((bindhost, 0))
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
self.socket.settimeout(5.0)
if self.tls:
self.tls_wrap()
self.cached_fileno = self.socket.fileno()
self.events.on("timer.rejoin").hook(self.try_rejoin) self.events.on("timer.rejoin").hook(self.try_rejoin)
def __repr__(self): def __repr__(self):
@ -84,8 +68,7 @@ class Server(IRCObject.Object):
return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "", return "%s:%s%s" % (self.target_hostname, "+" if self.tls else "",
self.port) self.port)
def fileno(self): def fileno(self):
fileno = self.socket.fileno() return self.socket.fileno()
return self.cached_fileno if fileno == -1 else fileno
def tls_wrap(self): def tls_wrap(self):
context = ssl.SSLContext(ssl.PROTOCOL_TLS) context = ssl.SSLContext(ssl.PROTOCOL_TLS)
@ -105,6 +88,14 @@ class Server(IRCObject.Object):
self.socket = context.wrap_socket(self.socket) self.socket = context.wrap_socket(self.socket)
def connect(self): def connect(self):
family = self.AF_INET if self.ipv4 else socket.AF_INET6
self.socket = socket.socket(family, socket.SOCK_STREAM)
self.socket.settimeout(5.0)
if bindhost:
self.socket.bind((bindhost, 0))
if self.tls:
self.tls_wrap()
self.socket.connect((self.target_hostname, self.port)) self.socket.connect((self.target_hostname, self.port))
self.send_capibility_ls() self.send_capibility_ls()