Skip to content

Commit 877b257

Browse files
authored
Merge pull request jpadilla#245 from jpadilla/fix-key-errors
Refactor error handling in Algorithm.prepare_key() methods
2 parents 1710c15 + d04339d commit 877b257

File tree

7 files changed

+68
-40
lines changed

7 files changed

+68
-40
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
88
-------------------------------------------------------------------------
99
### Changed
1010
- Add support for ECDSA public keys in RFC 4253 (OpenSSH) format [#244][244]
11+
- All Algorithm.prepare_key() calls now return either a valid key value or raise InvalidKeyError
1112
- Renamed commandline script `jwt` to `jwt-cli` to avoid issues with the script clobbering the `jwt` module in some circumstances.
1213
- Better error messages when using an algorithm that requires the cryptography package, but it isn't available [#230][230]
1314

jwt/algorithms.py

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55

66
from .compat import constant_time_compare, string_types
7-
from .exceptions import InvalidKeyError
7+
from .exceptions import (InvalidAsymmetricKeyError, InvalidJwkError,
8+
InvalidKeyError)
89
from .utils import (
910
base64url_decode, base64url_encode, der_to_raw_signature,
1011
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
@@ -137,6 +138,9 @@ def __init__(self, hash_alg):
137138
self.hash_alg = hash_alg
138139

139140
def prepare_key(self, key):
141+
if not isinstance(key, string_types):
142+
raise InvalidKeyError("HMAC secret key must be a string type.")
143+
140144
key = force_bytes(key)
141145

142146
invalid_strings = [
@@ -164,7 +168,7 @@ def from_jwk(jwk):
164168
obj = json.loads(jwk)
165169

166170
if obj.get('kty') != 'oct':
167-
raise InvalidKeyError('Not an HMAC key')
171+
raise InvalidKeyError('Invalid key: Not an HMAC key')
168172

169173
return base64url_decode(obj['k'])
170174

@@ -194,20 +198,28 @@ def prepare_key(self, key):
194198
isinstance(key, RSAPublicKey):
195199
return key
196200

197-
if isinstance(key, string_types):
198-
key = force_bytes(key)
201+
if not isinstance(key, string_types):
202+
raise InvalidAsymmetricKeyError
203+
204+
key = force_bytes(key)
199205

206+
if key.startswith(b'ssh-rsa'):
200207
try:
201-
if key.startswith(b'ssh-rsa'):
202-
key = load_ssh_public_key(key, backend=default_backend())
203-
else:
204-
key = load_pem_private_key(key, password=None, backend=default_backend())
208+
return load_ssh_public_key(key, backend=default_backend())
205209
except ValueError:
206-
key = load_pem_public_key(key, backend=default_backend())
207-
else:
208-
raise TypeError('Expecting a PEM-formatted key.')
210+
raise InvalidAsymmetricKeyError
211+
212+
try:
213+
return load_pem_private_key(key, password=None, backend=default_backend())
214+
except ValueError:
215+
pass
216+
217+
try:
218+
return load_pem_public_key(key, backend=default_backend())
219+
except ValueError:
220+
pass
209221

210-
return key
222+
raise InvalidAsymmetricKeyError
211223

212224
@staticmethod
213225
def to_jwk(key_obj):
@@ -241,7 +253,7 @@ def to_jwk(key_obj):
241253
'e': force_unicode(to_base64url_uint(numbers.e))
242254
}
243255
else:
244-
raise InvalidKeyError('Not a public or private key')
256+
raise InvalidKeyError('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance.')
245257

246258
return json.dumps(obj)
247259

@@ -250,22 +262,22 @@ def from_jwk(jwk):
250262
try:
251263
obj = json.loads(jwk)
252264
except ValueError:
253-
raise InvalidKeyError('Key is not valid JSON')
265+
raise InvalidJwkError('Key is not valid JSON')
254266

255267
if obj.get('kty') != 'RSA':
256-
raise InvalidKeyError('Not an RSA key')
268+
raise InvalidJwkError('Not an RSA key')
257269

258270
if 'd' in obj and 'e' in obj and 'n' in obj:
259271
# Private key
260272
if 'oth' in obj:
261-
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
273+
raise InvalidJwkError('Unsupported RSA private key: > 2 primes not supported')
262274

263275
other_props = ['p', 'q', 'dp', 'dq', 'qi']
264276
props_found = [prop in obj for prop in other_props]
265277
any_props_found = any(props_found)
266278

267279
if any_props_found and not all(props_found):
268-
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
280+
raise InvalidJwkError('RSA key must include all parameters if any are present besides d')
269281

270282
public_numbers = RSAPublicNumbers(
271283
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
@@ -306,7 +318,7 @@ def from_jwk(jwk):
306318

307319
return numbers.public_key(default_backend())
308320
else:
309-
raise InvalidKeyError('Not a public or private key')
321+
raise InvalidKeyError('Not a valid JWK public or private key')
310322

311323
def sign(self, msg, key):
312324
signer = key.signer(
@@ -349,24 +361,28 @@ def prepare_key(self, key):
349361
isinstance(key, EllipticCurvePublicKey):
350362
return key
351363

352-
if isinstance(key, string_types):
353-
key = force_bytes(key)
364+
if not isinstance(key, string_types):
365+
raise InvalidAsymmetricKeyError
366+
367+
key = force_bytes(key)
354368

355-
# Attempt to load key. We don't know if it's
356-
# a Signing Key or a Verifying Key, so we try
357-
# the Verifying Key first.
369+
if key.startswith(b'ecdsa-sha2-'):
358370
try:
359-
if key.startswith(b'ecdsa-sha2-'):
360-
key = load_ssh_public_key(key, backend=default_backend())
361-
else:
362-
key = load_pem_public_key(key, backend=default_backend())
371+
return load_ssh_public_key(key, backend=default_backend())
363372
except ValueError:
364-
key = load_pem_private_key(key, password=None, backend=default_backend())
373+
raise InvalidAsymmetricKeyError
365374

366-
else:
367-
raise TypeError('Expecting a PEM-formatted key.')
375+
try:
376+
return load_pem_public_key(key, backend=default_backend())
377+
except ValueError:
378+
pass
379+
380+
try:
381+
return load_pem_private_key(key, password=None, backend=default_backend())
382+
except ValueError:
383+
pass
368384

369-
return key
385+
raise InvalidAsymmetricKeyError
370386

371387
def sign(self, msg, key):
372388
signer = key.signer(ec.ECDSA(self.hash_alg()))

jwt/contrib/algorithms/py_ecdsa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from jwt.algorithms import Algorithm
99
from jwt.compat import string_types, text_type
10+
from jwt.exceptions import InvalidAsymmetricKeyError
1011

1112

1213
class ECAlgorithm(Algorithm):
@@ -44,7 +45,7 @@ def prepare_key(self, key):
4445
key = ecdsa.SigningKey.from_pem(key)
4546

4647
else:
47-
raise TypeError('Expecting a PEM-formatted key.')
48+
raise InvalidAsymmetricKeyError('Expecting a PEM-formatted key.')
4849

4950
return key
5051

jwt/contrib/algorithms/pycrypto.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from jwt.algorithms import Algorithm
99
from jwt.compat import string_types, text_type
10+
from jwt.exceptions import InvalidAsymmetricKeyError
1011

1112

1213
class RSAAlgorithm(Algorithm):
@@ -36,7 +37,7 @@ def prepare_key(self, key):
3637

3738
key = RSA.importKey(key)
3839
else:
39-
raise TypeError('Expecting a PEM- or RSA-formatted key.')
40+
raise InvalidAsymmetricKeyError
4041

4142
return key
4243

jwt/exceptions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@ class ImmatureSignatureError(InvalidTokenError):
2626
pass
2727

2828

29-
class InvalidKeyError(Exception):
29+
class InvalidKeyError(ValueError):
3030
pass
3131

3232

33+
class InvalidAsymmetricKeyError(InvalidKeyError):
34+
message = 'Invalid key: Keys must be in PEM or RFC 4253 format.'
35+
36+
37+
class InvalidJwkError(InvalidKeyError):
38+
message = 'Invalid key: Keys must be in JWK format.'
39+
40+
3341
class InvalidAlgorithmError(InvalidTokenError):
3442
pass
3543

tests/contrib/test_algorithms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22

3+
from jwt.exceptions import InvalidAsymmetricKeyError
34
from jwt.utils import force_bytes, force_unicode
45

56
import pytest
@@ -36,7 +37,7 @@ def test_rsa_should_accept_unicode_key(self):
3637
def test_rsa_should_reject_non_string_key(self):
3738
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
3839

39-
with pytest.raises(TypeError):
40+
with pytest.raises(InvalidAsymmetricKeyError):
4041
algo.prepare_key(None)
4142

4243
def test_rsa_sign_should_generate_correct_signature_value(self):
@@ -117,7 +118,7 @@ class TestEcdsaAlgorithms:
117118
def test_ec_should_reject_non_string_key(self):
118119
algo = ECAlgorithm(ECAlgorithm.SHA256)
119120

120-
with pytest.raises(TypeError):
121+
with pytest.raises(InvalidAsymmetricKeyError):
121122
algo.prepare_key(None)
122123

123124
def test_ec_should_accept_unicode_key(self):

tests/test_algorithms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
5858
def test_hmac_should_reject_nonstring_key(self):
5959
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
6060

61-
with pytest.raises(TypeError) as context:
61+
with pytest.raises(InvalidKeyError) as context:
6262
algo.prepare_key(object())
6363

6464
exception = context.value
65-
assert str(exception) == 'Expected a string value'
65+
assert str(exception) == 'HMAC secret key must be a string type.'
6666

6767
def test_hmac_should_accept_unicode_key(self):
6868
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
@@ -144,7 +144,7 @@ def test_rsa_should_accept_unicode_key(self):
144144
def test_rsa_should_reject_non_string_key(self):
145145
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
146146

147-
with pytest.raises(TypeError):
147+
with pytest.raises(InvalidKeyError):
148148
algo.prepare_key(None)
149149

150150
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
@@ -358,7 +358,7 @@ def test_rsa_from_jwk_raises_exception_on_invalid_key(self):
358358
def test_ec_should_reject_non_string_key(self):
359359
algo = ECAlgorithm(ECAlgorithm.SHA256)
360360

361-
with pytest.raises(TypeError):
361+
with pytest.raises(InvalidKeyError):
362362
algo.prepare_key(None)
363363

364364
@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')

0 commit comments

Comments
 (0)