Skip to content

Commit a4794be

Browse files
committed
Merge pull request jpadilla#110 from mark-adams/algo-whitelist
Added support for whitelist validation of the `alg` header
2 parents 838f824 + d4a0a22 commit a4794be

File tree

6 files changed

+119
-52
lines changed

6 files changed

+119
-52
lines changed

README.md

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@ $ pip install cryptography
2626

2727
```python
2828
import jwt
29-
jwt.encode({'some': 'payload'}, 'secret')
29+
jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256')
3030
```
3131

3232
Additional headers may also be specified.
3333

3434
```python
35-
jwt.encode({'some': 'payload'}, 'secret', headers={'kid': '230498151c214b788dd97f22b85410a5'})
35+
jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256', headers={'kid': '230498151c214b788dd97f22b85410a5'})
3636
```
3737

3838
Note the resulting JWT will not be encrypted, but verifiable with a secret key.
3939

4040
```python
41-
jwt.decode('someJWTstring', 'secret')
41+
jwt.decode('someJWTstring', 'secret', algorithms=['HS256'])
4242
```
4343

4444
If the secret is wrong, it will raise a `jwt.DecodeError` telling you as such.
@@ -83,12 +83,27 @@ currently supports:
8383
* RS384 - RSASSA-PKCS1-v1_5 signature algorithm using SHA-384 hash algorithm
8484
* RS512 - RSASSA-PKCS1-v1_5 signature algorithm using SHA-512 hash algorithm
8585

86-
Change the algorithm with by setting it in encode:
86+
### Encoding
87+
You can specify which algorithm you would like to use to sign the JWT
88+
by using the `algorithm` parameter:
8789

8890
```python
89-
jwt.encode({'some': 'payload'}, 'secret', 'HS512')
91+
jwt.encode({'some': 'payload'}, 'secret', algorithm='HS512')
9092
```
9193

94+
### Decoding
95+
When decoding, you can specify which algorithms you would like to permit
96+
when validating the JWT by using the `algorithms` parameter which takes a list
97+
of allowed algorithms:
98+
99+
```python
100+
jwt.decode(some_jwt, 'secret', algorithms=['HS512', 'HS256'])
101+
```
102+
103+
In the above case, if the JWT has any value for its alg header other than
104+
HS512 or HS256, the claim will be rejected with an `InvalidAlgorithmError`.
105+
106+
### Asymmetric (Public-key) Algorithms
92107
Usage of RSA (RS\*) and EC (EC\*) algorithms require a basic understanding
93108
of how public-key cryptography is used with regards to digital signatures.
94109
If you are unfamiliar, you may want to read
@@ -103,6 +118,7 @@ When using the ECDSA algorithms, the `key` argument is expected to
103118
be an Elliptic Curve public or private key in PEM format. The type of key
104119
(private or public) depends on whether you are signing or verifying.
105120

121+
106122
## Support of registered claim names
107123

108124
JSON Web Token defines some registered claim names and defines how they should

jwt/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
__copyright__ = 'Copyright 2015 José Padilla'
1717

1818

19-
from .api import encode, decode, register_algorithm, PyJWT
19+
from .api import (
20+
encode, decode, register_algorithm, unregister_algorithm, PyJWT
21+
)
2022
from .exceptions import (
2123
InvalidTokenError, DecodeError, ExpiredSignatureError,
2224
InvalidAudienceError, InvalidIssuerError,

jwt/algorithms.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,28 @@
1818
has_crypto = False
1919

2020

21-
def _register_default_algorithms(pyjwt_obj):
21+
def get_default_algorithms():
2222
"""
23-
Registers the algorithms that are implemented by the library.
23+
Returns the algorithms that are implemented by the library.
2424
"""
25-
pyjwt_obj.register_algorithm('none', NoneAlgorithm())
26-
pyjwt_obj.register_algorithm('HS256', HMACAlgorithm(HMACAlgorithm.SHA256))
27-
pyjwt_obj.register_algorithm('HS384', HMACAlgorithm(HMACAlgorithm.SHA384))
28-
pyjwt_obj.register_algorithm('HS512', HMACAlgorithm(HMACAlgorithm.SHA512))
25+
default_algorithms = {
26+
'none': NoneAlgorithm(),
27+
'HS256': HMACAlgorithm(HMACAlgorithm.SHA256),
28+
'HS384': HMACAlgorithm(HMACAlgorithm.SHA384),
29+
'HS512': HMACAlgorithm(HMACAlgorithm.SHA512)
30+
}
2931

3032
if has_crypto:
31-
pyjwt_obj.register_algorithm('RS256', RSAAlgorithm(RSAAlgorithm.SHA256))
32-
pyjwt_obj.register_algorithm('RS384', RSAAlgorithm(RSAAlgorithm.SHA384))
33-
pyjwt_obj.register_algorithm('RS512', RSAAlgorithm(RSAAlgorithm.SHA512))
34-
35-
pyjwt_obj.register_algorithm('ES256', ECAlgorithm(ECAlgorithm.SHA256))
36-
pyjwt_obj.register_algorithm('ES384', ECAlgorithm(ECAlgorithm.SHA384))
37-
pyjwt_obj.register_algorithm('ES512', ECAlgorithm(ECAlgorithm.SHA512))
33+
default_algorithms.update({
34+
'RS256': RSAAlgorithm(RSAAlgorithm.SHA256),
35+
'RS384': RSAAlgorithm(RSAAlgorithm.SHA384),
36+
'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
37+
'ES256': ECAlgorithm(ECAlgorithm.SHA256),
38+
'ES384': ECAlgorithm(ECAlgorithm.SHA384),
39+
'ES512': ECAlgorithm(ECAlgorithm.SHA512)
40+
})
41+
42+
return default_algorithms
3843

3944

4045
class Algorithm(object):

jwt/api.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@
55
from collections import Mapping
66
from datetime import datetime, timedelta
77

8-
from .algorithms import Algorithm, _register_default_algorithms # NOQA
8+
from .algorithms import Algorithm, get_default_algorithms # NOQA
99
from .compat import string_types, text_type, timedelta_total_seconds
1010
from .exceptions import (
11-
DecodeError, ExpiredSignatureError,
11+
DecodeError, ExpiredSignatureError, InvalidAlgorithmError,
1212
InvalidAudienceError, InvalidIssuerError
1313
)
1414
from .utils import base64url_decode, base64url_encode
1515

1616

1717
class PyJWT(object):
1818
def __init__(self, algorithms=None):
19-
self._algorithms = {}
19+
self._algorithms = get_default_algorithms()
20+
self._valid_algs = set(algorithms) if algorithms is not None else set(self._algorithms)
2021

21-
if algorithms is None:
22-
_register_default_algorithms(self)
23-
else:
24-
for key, algo in algorithms.items():
25-
self.register_algorithm(key, algo)
22+
# Remove algorithms that aren't on the whitelist
23+
for key in list(self._algorithms.keys()):
24+
if key not in self._valid_algs:
25+
del self._algorithms[key]
2626

2727
def register_algorithm(self, alg_id, alg_obj):
2828
"""
@@ -35,19 +35,34 @@ def register_algorithm(self, alg_id, alg_obj):
3535
raise TypeError('Object is not of type `Algorithm`')
3636

3737
self._algorithms[alg_id] = alg_obj
38+
self._valid_algs.add(alg_id)
3839

39-
def get_supported_algorithms(self):
40+
def unregister_algorithm(self, alg_id):
41+
"""
42+
Unregisters an Algorithm for use when creating and verifying tokens
43+
Throws KeyError if algorithm is not registered.
44+
"""
45+
if alg_id not in self._algorithms:
46+
raise KeyError('The specified algorithm could not be removed because it is not registered.')
47+
48+
del self._algorithms[alg_id]
49+
self._valid_algs.remove(alg_id)
50+
51+
def get_algorithms(self):
4052
"""
4153
Returns a list of supported values for the 'alg' parameter.
4254
"""
43-
return self._algorithms.keys()
55+
return list(self._valid_algs)
4456

4557
def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=None):
4658
segments = []
4759

4860
if algorithm is None:
4961
algorithm = 'none'
5062

63+
if algorithm not in self._valid_algs:
64+
pass
65+
5166
# Check that we get a mapping
5267
if not isinstance(payload, Mapping):
5368
raise TypeError('Expecting a mapping object, as json web token only'
@@ -94,12 +109,12 @@ def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=Non
94109

95110
return b'.'.join(segments)
96111

97-
def decode(self, jwt, key='', verify=True, **kwargs):
112+
def decode(self, jwt, key='', verify=True, algorithms=None, **kwargs):
98113
payload, signing_input, header, signature = self._load(jwt)
99114

100115
if verify:
101116
self._verify_signature(payload, signing_input, header, signature,
102-
key, **kwargs)
117+
key, algorithms, **kwargs)
103118

104119
return payload
105120

@@ -142,24 +157,29 @@ def _load(self, jwt):
142157
return (payload, signing_input, header, signature)
143158

144159
def _verify_signature(self, payload, signing_input, header, signature,
145-
key='', verify_expiration=True, leeway=0,
160+
key='', algorithms=None, verify_expiration=True, leeway=0,
146161
audience=None, issuer=None):
147162

163+
alg = header['alg']
164+
165+
if algorithms is not None and alg not in algorithms:
166+
raise InvalidAlgorithmError('The specified alg value is not allowed')
167+
148168
if isinstance(leeway, timedelta):
149169
leeway = timedelta_total_seconds(leeway)
150170

151171
if not isinstance(audience, (string_types, type(None))):
152172
raise TypeError('audience must be a string or None')
153173

154174
try:
155-
alg_obj = self._algorithms[header['alg']]
175+
alg_obj = self._algorithms[alg]
156176
key = alg_obj.prepare_key(key)
157177

158178
if not alg_obj.verify(signing_input, key, signature):
159179
raise DecodeError('Signature verification failed')
160180

161181
except KeyError:
162-
raise DecodeError('Algorithm not supported')
182+
raise InvalidAlgorithmError('Algorithm not supported')
163183

164184
if 'nbf' in payload and verify_expiration:
165185
utc_timestamp = timegm(datetime.utcnow().utctimetuple())
@@ -196,3 +216,4 @@ def _verify_signature(self, payload, signing_input, header, signature,
196216
encode = _jwt_global_obj.encode
197217
decode = _jwt_global_obj.decode
198218
register_algorithm = _jwt_global_obj.register_algorithm
219+
unregister_algorithm = _jwt_global_obj.unregister_algorithm

jwt/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class InvalidKeyError(Exception):
2222
pass
2323

2424

25+
class InvalidAlgorithmError(InvalidTokenError):
26+
pass
27+
28+
2529
# Compatibility aliases (deprecated)
2630
ExpiredSignature = ExpiredSignatureError
2731
InvalidAudience = InvalidAudienceError

tests/test_api.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from jwt.algorithms import Algorithm
1010
from jwt.api import PyJWT
11-
from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidAudienceError, InvalidIssuerError
11+
from jwt.exceptions import (
12+
DecodeError, ExpiredSignatureError, InvalidAlgorithmError,
13+
InvalidAudienceError, InvalidIssuerError
14+
)
1215

1316
from .compat import text_type, unittest
1417
from .utils import ensure_bytes
@@ -35,20 +38,6 @@ def setUp(self): # noqa
3538
'claim': 'insanity'}
3639
self.jwt = PyJWT()
3740

38-
def test_if_algorithms_param_is_empty_dict_then_no_algorithms(self):
39-
jwt_obj = PyJWT(algorithms={})
40-
jwt_algorithms = jwt_obj.get_supported_algorithms()
41-
42-
self.assertEqual(len(jwt_algorithms), 0)
43-
44-
def test_algorithms_param_sets_algorithms(self):
45-
algorithms = {'TESTALG': Algorithm()}
46-
jwt_obj = PyJWT(algorithms=algorithms)
47-
48-
supported_algs = jwt_obj.get_supported_algorithms()
49-
self.assertEqual(len(supported_algs), 1)
50-
self.assertIn('TESTALG', supported_algs)
51-
5241
def test_register_algorithm_does_not_allow_duplicate_registration(self):
5342
self.jwt.register_algorithm('AAA', Algorithm())
5443

@@ -59,13 +48,43 @@ def test_register_algorithm_rejects_non_algorithm_obj(self):
5948
with self.assertRaises(TypeError):
6049
self.jwt.register_algorithm('AAA123', {})
6150

51+
def test_unregister_algorithm_removes_algorithm(self):
52+
supported = self.jwt.get_algorithms()
53+
self.assertIn('none', supported)
54+
self.assertIn('HS256', supported)
55+
56+
self.jwt.unregister_algorithm('HS256')
57+
58+
supported = self.jwt.get_algorithms()
59+
self.assertNotIn('HS256', supported)
60+
61+
def test_unregister_algorithm_throws_error_if_not_registered(self):
62+
with self.assertRaises(KeyError):
63+
self.jwt.unregister_algorithm('AAA')
64+
65+
def test_algorithms_parameter_removes_alg_from_algorithms_list(self):
66+
self.assertIn('none', self.jwt.get_algorithms())
67+
self.assertIn('HS256', self.jwt.get_algorithms())
68+
69+
self.jwt = PyJWT(algorithms=['HS256'])
70+
self.assertNotIn('none', self.jwt.get_algorithms())
71+
self.assertIn('HS256', self.jwt.get_algorithms())
72+
6273
def test_encode_decode(self):
6374
secret = 'secret'
6475
jwt_message = self.jwt.encode(self.payload, secret)
6576
decoded_payload = self.jwt.decode(jwt_message, secret)
6677

6778
self.assertEqual(decoded_payload, self.payload)
6879

80+
def test_decode_fails_when_alg_is_not_on_method_algorithms_param(self):
81+
secret = 'secret'
82+
jwt_token = self.jwt.encode(self.payload, secret, algorithm='HS256')
83+
self.jwt.decode(jwt_token, secret)
84+
85+
with self.assertRaises(InvalidAlgorithmError):
86+
self.jwt.decode(jwt_token, secret, algorithms=['HS384'])
87+
6988
def test_decode_works_with_unicode_token(self):
7089
secret = 'secret'
7190
unicode_jwt = text_type(
@@ -170,7 +189,7 @@ def test_decode_algorithm_param_should_be_case_sensitive(self):
170189
'.eyJoZWxsbyI6IndvcmxkIn0'
171190
'.5R_FEPE7SW2dT9GgIxPgZATjFGXfUDOSwo7TtO_Kd_g')
172191

173-
with self.assertRaises(DecodeError) as context:
192+
with self.assertRaises(InvalidAlgorithmError) as context:
174193
self.jwt.decode(example_jwt, 'secret')
175194

176195
exception = context.exception
@@ -670,7 +689,7 @@ def test_encode_decode_with_rsa_sha512(self):
670689

671690
def test_rsa_related_algorithms(self):
672691
self.jwt = PyJWT()
673-
jwt_algorithms = self.jwt.get_supported_algorithms()
692+
jwt_algorithms = self.jwt.get_algorithms()
674693

675694
if has_crypto:
676695
self.assertTrue('RS256' in jwt_algorithms)
@@ -773,7 +792,7 @@ def test_encode_decode_with_ecdsa_sha512(self):
773792

774793
def test_ecdsa_related_algorithms(self):
775794
self.jwt = PyJWT()
776-
jwt_algorithms = self.jwt.get_supported_algorithms()
795+
jwt_algorithms = self.jwt.get_algorithms()
777796

778797
if has_crypto:
779798
self.assertTrue('ES256' in jwt_algorithms)

0 commit comments

Comments
 (0)