Type annotate scram.py and don't pass base64 data to scram.py functions

This commit is contained in:
jesopo 2019-02-06 08:50:19 +00:00
parent 7e6446dc52
commit 6b4bb7cdba
2 changed files with 13 additions and 13 deletions

View file

@ -74,10 +74,11 @@ class Module(ModuleManager.BaseModule):
auth_text = event["server"]._scram.client_first() auth_text = event["server"]._scram.client_first()
else: else:
current_scram = event["server"]._scram current_scram = event["server"]._scram
data = base64.b64decode(event["message"])
if current_scram.state == scram.SCRAMState.ClientFirst: if current_scram.state == scram.SCRAMState.ClientFirst:
auth_text = current_scram.server_first(event["message"]) auth_text = current_scram.server_first(data)
elif current_scram.state == scram.SCRAMState.ClientFinal: elif current_scram.state == scram.SCRAMState.ClientFinal:
auth_text = current_scram.server_final(event["message"]) auth_text = current_scram.server_final(data)
del event["server"]._scram del event["server"]._scram
if current_scram.state == scram.SCRAMState.VerifyFailed: if current_scram.state == scram.SCRAMState.VerifyFailed:

View file

@ -1,4 +1,4 @@
import base64, enum, hashlib, hmac, uuid import base64, enum, hashlib, hmac, typing, uuid
def _scram_nonce(): def _scram_nonce():
return uuid.uuid4().hex.encode("utf8") return uuid.uuid4().hex.encode("utf8")
@ -30,26 +30,25 @@ class SCRAM(object):
self._salted_password = None self._salted_password = None
self._auth_message = None self._auth_message = None
def _get_data(self, message): def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]:
data = base64.b64decode(message) return dict(piece.split(b"=", 1) for piece in data.split(b","))
return data, dict(piece.split(b"=", 1) for piece in data.split(b","))
def _hmac(self, key, msg): def _hmac(self, key: bytes, msg: bytes) -> bytes:
return hmac.digest(key, msg, self._algo) return hmac.digest(key, msg, self._algo)
def _hash(self, msg): def _hash(self, msg: bytes) -> bytes:
return hashlib.new(self._algo, msg).digest() return hashlib.new(self._algo, msg).digest()
def client_first(self): def client_first(self) -> bytes:
self.state = SCRAMState.ClientFirst self.state = SCRAMState.ClientFirst
# start SCRAM handshake # start SCRAM handshake
self._client_first = b"n=%s,r=%s" % ( self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), _scram_nonce()) _scram_escape(self._username), _scram_nonce())
return b"n,,%s" % self._client_first return b"n,,%s" % self._client_first
def server_first(self, message): def server_first(self, data: bytes) -> bytes:
self.state = SCRAMState.ClientFinal self.state = SCRAMState.ClientFinal
data, pieces = self._get_data(message) pieces = self._get_pieces(data)
# server-first-message # server-first-message
nonce = pieces[b"r"] nonce = pieces[b"r"]
salt = base64.b64decode(pieces[b"s"]) salt = base64.b64decode(pieces[b"s"])
@ -74,9 +73,9 @@ class SCRAM(object):
return auth_noproof + (b",p=%s" % client_proof) return auth_noproof + (b",p=%s" % client_proof)
def server_final(self, message): def server_final(self, data: bytes) -> bytes:
# server-final-message # server-final-message
data, pieces = self._get_data(message) pieces = self._get_pieces(data)
verifier = pieces[b"v"] verifier = pieces[b"v"]
server_key = self._hmac(self._salted_password, b"Server Key") server_key = self._hmac(self._salted_password, b"Server Key")