Skip to content

Commit 68d1f89

Browse files
committed
Added support for RSASSA-PSS algorithms (PS256, PS384, PS512)
1 parent 2f4c770 commit 68d1f89

File tree

3 files changed

+117
-14
lines changed

3 files changed

+117
-14
lines changed

jwt/algorithms.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def get_default_algorithms():
4242
'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
4343
'ES256': ECAlgorithm(ECAlgorithm.SHA256),
4444
'ES384': ECAlgorithm(ECAlgorithm.SHA384),
45-
'ES512': ECAlgorithm(ECAlgorithm.SHA512)
45+
'ES512': ECAlgorithm(ECAlgorithm.SHA512),
46+
'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
47+
'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
48+
'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
4649
})
4750

4851
return default_algorithms
@@ -145,7 +148,7 @@ class RSAAlgorithm(Algorithm):
145148
SHA512 = hashes.SHA512
146149

147150
def __init__(self, hash_alg):
148-
self.hash_alg = hash_alg()
151+
self.hash_alg = hash_alg
149152

150153
def prepare_key(self, key):
151154
if isinstance(key, RSAPrivateKey) or \
@@ -171,7 +174,7 @@ def prepare_key(self, key):
171174
def sign(self, msg, key):
172175
signer = key.signer(
173176
padding.PKCS1v15(),
174-
self.hash_alg
177+
self.hash_alg()
175178
)
176179

177180
signer.update(msg)
@@ -181,7 +184,7 @@ def verify(self, msg, key, sig):
181184
verifier = key.verifier(
182185
sig,
183186
padding.PKCS1v15(),
184-
self.hash_alg
187+
self.hash_alg()
185188
)
186189

187190
verifier.update(msg)
@@ -202,7 +205,7 @@ class ECAlgorithm(Algorithm):
202205
SHA512 = hashes.SHA512
203206

204207
def __init__(self, hash_alg):
205-
self.hash_alg = hash_alg()
208+
self.hash_alg = hash_alg
206209

207210
def prepare_key(self, key):
208211
if isinstance(key, EllipticCurvePrivateKey) or \
@@ -227,13 +230,48 @@ def prepare_key(self, key):
227230
return key
228231

229232
def sign(self, msg, key):
230-
signer = key.signer(ec.ECDSA(self.hash_alg))
233+
signer = key.signer(ec.ECDSA(self.hash_alg()))
231234

232235
signer.update(msg)
233236
return signer.finalize()
234237

235238
def verify(self, msg, key, sig):
236-
verifier = key.verifier(sig, ec.ECDSA(self.hash_alg))
239+
verifier = key.verifier(sig, ec.ECDSA(self.hash_alg()))
240+
241+
verifier.update(msg)
242+
243+
try:
244+
verifier.verify()
245+
return True
246+
except InvalidSignature:
247+
return False
248+
249+
class RSAPSSAlgorithm(RSAAlgorithm):
250+
"""
251+
Performs a signature using RSASSA-PSS with MGF1
252+
"""
253+
254+
def sign(self, msg, key):
255+
signer = key.signer(
256+
padding.PSS(
257+
mgf=padding.MGF1(self.hash_alg()),
258+
salt_length=padding.PSS.MAX_LENGTH
259+
),
260+
self.hash_alg()
261+
)
262+
263+
signer.update(msg)
264+
return signer.finalize()
265+
266+
def verify(self, msg, key, sig):
267+
verifier = key.verifier(
268+
sig,
269+
padding.PSS(
270+
mgf=padding.MGF1(self.hash_alg()),
271+
salt_length=padding.PSS.MAX_LENGTH
272+
),
273+
self.hash_alg()
274+
)
237275

238276
verifier.update(msg)
239277

tests/test_algorithms.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .utils import ensure_bytes, ensure_unicode, key_path
88

99
try:
10-
from jwt.algorithms import RSAAlgorithm, ECAlgorithm
10+
from jwt.algorithms import RSAAlgorithm, ECAlgorithm, RSAPSSAlgorithm
1111

1212
has_crypto = True
1313
except ImportError:
@@ -169,34 +169,92 @@ def test_ec_should_accept_unicode_key(self):
169169
def test_ec_verify_should_return_false_if_signature_invalid(self):
170170
algo = ECAlgorithm(ECAlgorithm.SHA256)
171171

172-
jwt_message = ensure_bytes('Hello World!')
172+
message = ensure_bytes('Hello World!')
173173

174174
# Mess up the signature by replacing a known byte
175-
jwt_sig = base64.b64decode(ensure_bytes(
175+
sig = base64.b64decode(ensure_bytes(
176176
'MIGIAkIB9vYz+inBL8aOTA4auYz/zVuig7TT1bQgKROIQX9YpViHkFa4DT5'
177177
'5FuFKn9XzVlk90p6ldEj42DC9YecXHbC2t+cCQgCicY+8f3f/KCNtWK7cif'
178178
'6vdsVwm6Lrjs0Ag6ZqCf+olN11hVt1qKBC4lXppqB1gNWEmNQaiz1z2QRyc'
179179
'zJ8hSJmbw=='.replace('r', 's')))
180180

181181
with open(key_path('testkey_ec.pub'), 'r') as keyfile:
182-
jwt_pub_key = algo.prepare_key(keyfile.read())
182+
pub_key = algo.prepare_key(keyfile.read())
183183

184-
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
184+
result = algo.verify(message, pub_key, sig)
185185
self.assertFalse(result)
186186

187187
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
188188
def test_ec_verify_should_return_true_if_signature_valid(self):
189189
algo = ECAlgorithm(ECAlgorithm.SHA256)
190190

191-
jwt_message = ensure_bytes('Hello World!')
191+
message = ensure_bytes('Hello World!')
192192

193-
jwt_sig = base64.b64decode(ensure_bytes(
193+
sig = base64.b64decode(ensure_bytes(
194194
'MIGIAkIB9vYz+inBL8aOTA4auYz/zVuig7TT1bQgKROIQX9YpViHkFa4DT5'
195195
'5FuFKn9XzVlk90p6ldEj42DC9YecXHbC2t+cCQgCicY+8f3f/KCNtWK7cif'
196196
'6vdsVwm6Lrjs0Ag6ZqCf+olN11hVt1qKBC4lXppqB1gNWEmNQaiz1z2QRyc'
197197
'zJ8hSJmbw=='))
198198

199199
with open(key_path('testkey_ec.pub'), 'r') as keyfile:
200+
pub_key = algo.prepare_key(keyfile.read())
201+
202+
result = algo.verify(message, pub_key, sig)
203+
self.assertTrue(result)
204+
205+
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
206+
def test_rsa_pss_sign_then_verify_should_return_true(self):
207+
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
208+
209+
message = ensure_bytes('Hello World!')
210+
211+
with open(key_path('testkey_rsa'), 'r') as keyfile:
212+
priv_key = algo.prepare_key(keyfile.read())
213+
sig = algo.sign(message, priv_key)
214+
215+
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
216+
pub_key = algo.prepare_key(keyfile.read())
217+
218+
result = algo.verify(message, pub_key, sig)
219+
self.assertTrue(result)
220+
221+
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
222+
def test_rsa_pss_verify_should_return_false_if_signature_invalid(self):
223+
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
224+
225+
jwt_message = ensure_bytes('Hello World!')
226+
227+
jwt_sig = base64.b64decode(ensure_bytes(
228+
'ywKAUGRIDC//6X+tjvZA96yEtMqpOrSppCNfYI7NKyon3P7doud5v65oWNu'
229+
'vQsz0fzPGfF7mQFGo9Cm9Vn0nljm4G6PtqZRbz5fXNQBH9k10gq34AtM02c'
230+
'/cveqACQ8gF3zxWh6qr9jVqIpeMEaEBIkvqG954E0HT9s9ybHShgHX9mlWk'
231+
'186/LopP4xe5c/hxOQjwhv6yDlTiwJFiqjNCvj0GyBKsc4iECLGIIO+4mC4'
232+
'daOCWqbpZDuLb1imKpmm8Nsm56kAxijMLZnpCcnPgyb7CqG+B93W9GHglA5'
233+
'drUeR1gRtO7vqbZMsCAQ4bpjXxwbYyjQlEVuMl73UL6sOWg=='))
234+
235+
jwt_sig += ensure_bytes('123') # Signature is now invalid
236+
237+
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
238+
jwt_pub_key = algo.prepare_key(keyfile.read())
239+
240+
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
241+
self.assertFalse(result)
242+
243+
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
244+
def test_rsa_pss_verify_should_return_true_if_signature_valid(self):
245+
algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
246+
247+
jwt_message = ensure_bytes('Hello World!')
248+
249+
jwt_sig = base64.b64decode(ensure_bytes(
250+
'ywKAUGRIDC//6X+tjvZA96yEtMqpOrSppCNfYI7NKyon3P7doud5v65oWNu'
251+
'vQsz0fzPGfF7mQFGo9Cm9Vn0nljm4G6PtqZRbz5fXNQBH9k10gq34AtM02c'
252+
'/cveqACQ8gF3zxWh6qr9jVqIpeMEaEBIkvqG954E0HT9s9ybHShgHX9mlWk'
253+
'186/LopP4xe5c/hxOQjwhv6yDlTiwJFiqjNCvj0GyBKsc4iECLGIIO+4mC4'
254+
'daOCWqbpZDuLb1imKpmm8Nsm56kAxijMLZnpCcnPgyb7CqG+B93W9GHglA5'
255+
'drUeR1gRtO7vqbZMsCAQ4bpjXxwbYyjQlEVuMl73UL6sOWg=='))
256+
257+
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
200258
jwt_pub_key = algo.prepare_key(keyfile.read())
201259

202260
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)

tests/test_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,17 @@ def test_rsa_related_algorithms(self):
615615
self.assertTrue('RS256' in jwt_algorithms)
616616
self.assertTrue('RS384' in jwt_algorithms)
617617
self.assertTrue('RS512' in jwt_algorithms)
618+
self.assertTrue('PS256' in jwt_algorithms)
619+
self.assertTrue('PS384' in jwt_algorithms)
620+
self.assertTrue('PS512' in jwt_algorithms)
621+
618622
else:
619623
self.assertFalse('RS256' in jwt_algorithms)
620624
self.assertFalse('RS384' in jwt_algorithms)
621625
self.assertFalse('RS512' in jwt_algorithms)
626+
self.assertFalse('PS256' in jwt_algorithms)
627+
self.assertFalse('PS384' in jwt_algorithms)
628+
self.assertFalse('PS512' in jwt_algorithms)
622629

623630
@unittest.skipIf(not has_crypto, "Can't run without cryptography library")
624631
def test_encode_decode_with_ecdsa_sha256(self):

0 commit comments

Comments
 (0)