Remove mention of ipv4 - detect address family automatically

This commit is contained in:
jesopo 2019-05-21 10:11:09 +01:00
parent 9a5ba753d0
commit ab1074cf4d
5 changed files with 16 additions and 22 deletions

View file

@ -9,16 +9,16 @@ class Table(object):
class Servers(Table): class Servers(Table):
def add(self, alias: str, hostname: str, port: int, password: str, def add(self, alias: str, hostname: str, port: int, password: str,
ipv4: bool, tls: bool, bindhost: str, tls: bool, bindhost: str, nickname: str, username: str=None,
nickname: str, username: str=None, realname: str=None): realname: str=None):
username = username or nickname username = username or nickname
realname = realname or nickname realname = realname or nickname
self.database.execute( self.database.execute(
"""INSERT INTO servers (alias, hostname, port, password, tls, """INSERT INTO servers (alias, hostname, port, password, tls,
ipv4, bindhost, nickname, username, realname) VALUES ( bindhost, nickname, username, realname) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
[alias, hostname, port, password, tls, ipv4, bindhost, nickname, [alias, hostname, port, password, tls, bindhost, nickname, username,
username, realname]) realname])
return self.database.execute_fetchone( return self.database.execute_fetchone(
"SELECT server_id FROM servers ORDER BY server_id DESC LIMIT 1")[0] "SELECT server_id FROM servers ORDER BY server_id DESC LIMIT 1")[0]
def get_all(self): def get_all(self):
@ -29,7 +29,7 @@ class Servers(Table):
typing.Optional[str], typing.Optional[str]]: typing.Optional[str], typing.Optional[str]]:
return self.database.execute_fetchone( return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, tls, """SELECT server_id, alias, hostname, port, password, tls,
ipv4, bindhost, nickname, username, realname FROM servers WHERE bindhost, nickname, username, realname FROM servers WHERE
server_id=?""", server_id=?""",
[id]) [id])
def get_by_alias(self, alias: str) -> typing.Optional[int]: def get_by_alias(self, alias: str) -> typing.Optional[int]:
@ -352,7 +352,7 @@ class Database(object):
if not self.has_table("servers"): if not self.has_table("servers"):
self.execute("""CREATE TABLE servers self.execute("""CREATE TABLE servers
(server_id INTEGER PRIMARY KEY, alias TEXT, hostname TEXT, (server_id INTEGER PRIMARY KEY, alias TEXT, hostname TEXT,
port INTEGER, password TEXT, ipv4 BOOLEAN, tls BOOLEAN, port INTEGER, password TEXT, tls BOOLEAN,
bindhost TEXT, nickname TEXT, username TEXT, realname TEXT, bindhost TEXT, nickname TEXT, username TEXT, realname TEXT,
UNIQUE (alias))""") UNIQUE (alias))""")
def make_channels_table(self): def make_channels_table(self):

View file

@ -83,7 +83,6 @@ class Server(IRCObject.Object):
self.get_setting("fallback-encoding", "iso-8859-1"), self.get_setting("fallback-encoding", "iso-8859-1"),
self.connection_params.hostname, self.connection_params.hostname,
self.connection_params.port, self.connection_params.port,
self.connection_params.ipv4,
self.connection_params.bindhost, self.connection_params.bindhost,
self.connection_params.tls, self.connection_params.tls,
tls_verify=self.get_setting("ssl-verify", True), tls_verify=self.get_setting("ssl-verify", True),

View file

@ -7,7 +7,7 @@ UNTHROTTLED_MAX_LINES = 10
class Socket(IRCObject.Object): class Socket(IRCObject.Object):
def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str, def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str,
hostname: str, port: int, ipv4: bool, bindhost: str, tls: bool, hostname: str, port: int, bindhost: str, tls: bool,
tls_verify: bool=True, cert: str=None, key: str=None): tls_verify: bool=True, cert: str=None, key: str=None):
self.log = log self.log = log
@ -15,7 +15,6 @@ class Socket(IRCObject.Object):
self._fallback_encoding = fallback_encoding self._fallback_encoding = fallback_encoding
self._hostname = hostname self._hostname = hostname
self._port = port self._port = port
self._ipv4 = ipv4
self._bindhost = bindhost self._bindhost = bindhost
self._tls = tls self._tls = tls
@ -53,17 +52,15 @@ class Socket(IRCObject.Object):
hostname=server_hostname) hostname=server_hostname)
def connect(self): def connect(self):
family = socket.AF_INET if self._ipv4 else socket.AF_INET6 bindhost = None
self._socket = socket.socket(family, socket.SOCK_STREAM)
self._socket.settimeout(5.0)
if self._bindhost: if self._bindhost:
self._socket.bind((self._bindhost, 0)) bindhost = (self._bindhost, 0)
self._socket = socket.create_connection((self._hostname, self._port),
5.0, bindhost)
if self._tls: if self._tls:
self._tls_wrap() self._tls_wrap()
self._socket.connect((self._hostname, self._port))
self.connected_ip = self._socket.getpeername()[0] self.connected_ip = self._socket.getpeername()[0]
self.cached_fileno = self._socket.fileno() self.cached_fileno = self._socket.fileno()
self.connected = True self.connected = True

View file

@ -10,11 +10,10 @@ def add_server(database: "Database.Database"):
port = int(input("port: ")) port = int(input("port: "))
tls = bool_input("tls?") tls = bool_input("tls?")
password = input("password?: ") password = input("password?: ")
ipv4 = bool_input("ipv4?")
nickname = input("nickname: ") nickname = input("nickname: ")
username = input("username: ") username = input("username: ")
realname = input("realname: ") realname = input("realname: ")
bindhost = input("bindhost?: ") bindhost = input("bindhost?: ")
server_id = database.servers.add(alias, hostname, port, password, ipv4, tls, server_id = database.servers.add(alias, hostname, port, password, tls,
bindhost, nickname, username, realname) bindhost, nickname, username, realname)

View file

@ -226,15 +226,14 @@ def parse_format(s: str) -> str:
OPT_STR = typing.Optional[str] OPT_STR = typing.Optional[str]
class IRCConnectionParameters(object): class IRCConnectionParameters(object):
def __init__(self, id: int, alias: str, hostname: str, port: int, def __init__(self, id: int, alias: str, hostname: str, port: int,
password: OPT_STR, tls: bool, ipv4: bool, bindhost: OPT_STR, password: OPT_STR, tls: bool, bindhost: OPT_STR, nickname: str,
nickname: str, username: OPT_STR, realname: OPT_STR, username: OPT_STR, realname: OPT_STR,
args: typing.Dict[str, str]={}): args: typing.Dict[str, str]={}):
self.id = id self.id = id
self.alias = alias self.alias = alias
self.hostname = hostname self.hostname = hostname
self.port = port self.port = port
self.tls = tls self.tls = tls
self.ipv4 = ipv4
self.bindhost = bindhost self.bindhost = bindhost
self.password = password self.password = password
self.nickname = nickname self.nickname = nickname