From ab1074cf4d91224b35fbbc217be4d39d5d5029a0 Mon Sep 17 00:00:00 2001 From: jesopo Date: Tue, 21 May 2019 10:11:09 +0100 Subject: [PATCH] Remove mention of `ipv4` - detect address family automatically --- src/Database.py | 14 +++++++------- src/IRCServer.py | 1 - src/IRCSocket.py | 15 ++++++--------- src/utils/cli.py | 3 +-- src/utils/irc/__init__.py | 5 ++--- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/Database.py b/src/Database.py index 67fe4a6f..6c276754 100644 --- a/src/Database.py +++ b/src/Database.py @@ -9,16 +9,16 @@ class Table(object): class Servers(Table): def add(self, alias: str, hostname: str, port: int, password: str, - ipv4: bool, tls: bool, bindhost: str, - nickname: str, username: str=None, realname: str=None): + tls: bool, bindhost: str, nickname: str, username: str=None, + realname: str=None): username = username or nickname realname = realname or nickname self.database.execute( """INSERT INTO servers (alias, hostname, port, password, tls, - ipv4, bindhost, nickname, username, realname) VALUES ( + bindhost, nickname, username, realname) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - [alias, hostname, port, password, tls, ipv4, bindhost, nickname, - username, realname]) + [alias, hostname, port, password, tls, bindhost, nickname, username, + realname]) return self.database.execute_fetchone( "SELECT server_id FROM servers ORDER BY server_id DESC LIMIT 1")[0] def get_all(self): @@ -29,7 +29,7 @@ class Servers(Table): typing.Optional[str], typing.Optional[str]]: return self.database.execute_fetchone( """SELECT server_id, alias, hostname, port, password, tls, - ipv4, bindhost, nickname, username, realname FROM servers WHERE + bindhost, nickname, username, realname FROM servers WHERE server_id=?""", [id]) def get_by_alias(self, alias: str) -> typing.Optional[int]: @@ -352,7 +352,7 @@ class Database(object): if not self.has_table("servers"): self.execute("""CREATE TABLE servers (server_id INTEGER PRIMARY KEY, alias TEXT, hostname TEXT, - port INTEGER, password TEXT, ipv4 BOOLEAN, tls BOOLEAN, + port INTEGER, password TEXT, tls BOOLEAN, bindhost TEXT, nickname TEXT, username TEXT, realname TEXT, UNIQUE (alias))""") def make_channels_table(self): diff --git a/src/IRCServer.py b/src/IRCServer.py index bd01cbfb..3a4ff182 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -83,7 +83,6 @@ class Server(IRCObject.Object): self.get_setting("fallback-encoding", "iso-8859-1"), self.connection_params.hostname, self.connection_params.port, - self.connection_params.ipv4, self.connection_params.bindhost, self.connection_params.tls, tls_verify=self.get_setting("ssl-verify", True), diff --git a/src/IRCSocket.py b/src/IRCSocket.py index b3b17b6d..3c32f64f 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -7,7 +7,7 @@ UNTHROTTLED_MAX_LINES = 10 class Socket(IRCObject.Object): def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str, - hostname: str, port: int, ipv4: bool, bindhost: str, tls: bool, + hostname: str, port: int, bindhost: str, tls: bool, tls_verify: bool=True, cert: str=None, key: str=None): self.log = log @@ -15,7 +15,6 @@ class Socket(IRCObject.Object): self._fallback_encoding = fallback_encoding self._hostname = hostname self._port = port - self._ipv4 = ipv4 self._bindhost = bindhost self._tls = tls @@ -53,17 +52,15 @@ class Socket(IRCObject.Object): hostname=server_hostname) def connect(self): - family = socket.AF_INET if self._ipv4 else socket.AF_INET6 - self._socket = socket.socket(family, socket.SOCK_STREAM) - - self._socket.settimeout(5.0) - + bindhost = None if self._bindhost: - self._socket.bind((self._bindhost, 0)) + bindhost = (self._bindhost, 0) + self._socket = socket.create_connection((self._hostname, self._port), + 5.0, bindhost) + if self._tls: self._tls_wrap() - self._socket.connect((self._hostname, self._port)) self.connected_ip = self._socket.getpeername()[0] self.cached_fileno = self._socket.fileno() self.connected = True diff --git a/src/utils/cli.py b/src/utils/cli.py index 638ff749..1a8e81e8 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -10,11 +10,10 @@ def add_server(database: "Database.Database"): port = int(input("port: ")) tls = bool_input("tls?") password = input("password?: ") - ipv4 = bool_input("ipv4?") nickname = input("nickname: ") username = input("username: ") realname = input("realname: ") bindhost = input("bindhost?: ") - server_id = database.servers.add(alias, hostname, port, password, ipv4, tls, + server_id = database.servers.add(alias, hostname, port, password, tls, bindhost, nickname, username, realname) diff --git a/src/utils/irc/__init__.py b/src/utils/irc/__init__.py index a97b9479..14bbb5b0 100644 --- a/src/utils/irc/__init__.py +++ b/src/utils/irc/__init__.py @@ -226,15 +226,14 @@ def parse_format(s: str) -> str: OPT_STR = typing.Optional[str] class IRCConnectionParameters(object): def __init__(self, id: int, alias: str, hostname: str, port: int, - password: OPT_STR, tls: bool, ipv4: bool, bindhost: OPT_STR, - nickname: str, username: OPT_STR, realname: OPT_STR, + password: OPT_STR, tls: bool, bindhost: OPT_STR, nickname: str, + username: OPT_STR, realname: OPT_STR, args: typing.Dict[str, str]={}): self.id = id self.alias = alias self.hostname = hostname self.port = port self.tls = tls - self.ipv4 = ipv4 self.bindhost = bindhost self.password = password self.nickname = nickname