Type annotate scram.py and don't pass base64 data to scram.py functions
This commit is contained in:
parent
7e6446dc52
commit
6b4bb7cdba
2 changed files with 13 additions and 13 deletions
|
@ -74,10 +74,11 @@ class Module(ModuleManager.BaseModule):
|
||||||
auth_text = event["server"]._scram.client_first()
|
auth_text = event["server"]._scram.client_first()
|
||||||
else:
|
else:
|
||||||
current_scram = event["server"]._scram
|
current_scram = event["server"]._scram
|
||||||
|
data = base64.b64decode(event["message"])
|
||||||
if current_scram.state == scram.SCRAMState.ClientFirst:
|
if current_scram.state == scram.SCRAMState.ClientFirst:
|
||||||
auth_text = current_scram.server_first(event["message"])
|
auth_text = current_scram.server_first(data)
|
||||||
elif current_scram.state == scram.SCRAMState.ClientFinal:
|
elif current_scram.state == scram.SCRAMState.ClientFinal:
|
||||||
auth_text = current_scram.server_final(event["message"])
|
auth_text = current_scram.server_final(data)
|
||||||
del event["server"]._scram
|
del event["server"]._scram
|
||||||
|
|
||||||
if current_scram.state == scram.SCRAMState.VerifyFailed:
|
if current_scram.state == scram.SCRAMState.VerifyFailed:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import base64, enum, hashlib, hmac, uuid
|
import base64, enum, hashlib, hmac, typing, uuid
|
||||||
|
|
||||||
def _scram_nonce():
|
def _scram_nonce():
|
||||||
return uuid.uuid4().hex.encode("utf8")
|
return uuid.uuid4().hex.encode("utf8")
|
||||||
|
@ -30,26 +30,25 @@ class SCRAM(object):
|
||||||
self._salted_password = None
|
self._salted_password = None
|
||||||
self._auth_message = None
|
self._auth_message = None
|
||||||
|
|
||||||
def _get_data(self, message):
|
def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]:
|
||||||
data = base64.b64decode(message)
|
return dict(piece.split(b"=", 1) for piece in data.split(b","))
|
||||||
return data, dict(piece.split(b"=", 1) for piece in data.split(b","))
|
|
||||||
|
|
||||||
def _hmac(self, key, msg):
|
def _hmac(self, key: bytes, msg: bytes) -> bytes:
|
||||||
return hmac.digest(key, msg, self._algo)
|
return hmac.digest(key, msg, self._algo)
|
||||||
def _hash(self, msg):
|
def _hash(self, msg: bytes) -> bytes:
|
||||||
return hashlib.new(self._algo, msg).digest()
|
return hashlib.new(self._algo, msg).digest()
|
||||||
|
|
||||||
def client_first(self):
|
def client_first(self) -> bytes:
|
||||||
self.state = SCRAMState.ClientFirst
|
self.state = SCRAMState.ClientFirst
|
||||||
# start SCRAM handshake
|
# start SCRAM handshake
|
||||||
self._client_first = b"n=%s,r=%s" % (
|
self._client_first = b"n=%s,r=%s" % (
|
||||||
_scram_escape(self._username), _scram_nonce())
|
_scram_escape(self._username), _scram_nonce())
|
||||||
return b"n,,%s" % self._client_first
|
return b"n,,%s" % self._client_first
|
||||||
|
|
||||||
def server_first(self, message):
|
def server_first(self, data: bytes) -> bytes:
|
||||||
self.state = SCRAMState.ClientFinal
|
self.state = SCRAMState.ClientFinal
|
||||||
|
|
||||||
data, pieces = self._get_data(message)
|
pieces = self._get_pieces(data)
|
||||||
# server-first-message
|
# server-first-message
|
||||||
nonce = pieces[b"r"]
|
nonce = pieces[b"r"]
|
||||||
salt = base64.b64decode(pieces[b"s"])
|
salt = base64.b64decode(pieces[b"s"])
|
||||||
|
@ -74,9 +73,9 @@ class SCRAM(object):
|
||||||
|
|
||||||
return auth_noproof + (b",p=%s" % client_proof)
|
return auth_noproof + (b",p=%s" % client_proof)
|
||||||
|
|
||||||
def server_final(self, message):
|
def server_final(self, data: bytes) -> bytes:
|
||||||
# server-final-message
|
# server-final-message
|
||||||
data, pieces = self._get_data(message)
|
pieces = self._get_pieces(data)
|
||||||
verifier = pieces[b"v"]
|
verifier = pieces[b"v"]
|
||||||
|
|
||||||
server_key = self._hmac(self._salted_password, b"Server Key")
|
server_key = self._hmac(self._salted_password, b"Server Key")
|
||||||
|
|
Loading…
Reference in a new issue