44
55
66from .compat import constant_time_compare , string_types
7- from .exceptions import InvalidKeyError
7+ from .exceptions import (InvalidAsymmetricKeyError , InvalidJwkError ,
8+ InvalidKeyError )
89from .utils import (
910 base64url_decode , base64url_encode , der_to_raw_signature ,
1011 force_bytes , force_unicode , from_base64url_uint , raw_to_der_signature ,
@@ -137,6 +138,9 @@ def __init__(self, hash_alg):
137138 self .hash_alg = hash_alg
138139
139140 def prepare_key (self , key ):
141+ if not isinstance (key , string_types ):
142+ raise InvalidKeyError ("HMAC secret key must be a string type." )
143+
140144 key = force_bytes (key )
141145
142146 invalid_strings = [
@@ -164,7 +168,7 @@ def from_jwk(jwk):
164168 obj = json .loads (jwk )
165169
166170 if obj .get ('kty' ) != 'oct' :
167- raise InvalidKeyError ('Not an HMAC key' )
171+ raise InvalidKeyError ('Invalid key: Not an HMAC key' )
168172
169173 return base64url_decode (obj ['k' ])
170174
@@ -194,20 +198,28 @@ def prepare_key(self, key):
194198 isinstance (key , RSAPublicKey ):
195199 return key
196200
197- if isinstance (key , string_types ):
198- key = force_bytes (key )
201+ if not isinstance (key , string_types ):
202+ raise InvalidAsymmetricKeyError
203+
204+ key = force_bytes (key )
199205
206+ if key .startswith (b'ssh-rsa' ):
200207 try :
201- if key .startswith (b'ssh-rsa' ):
202- key = load_ssh_public_key (key , backend = default_backend ())
203- else :
204- key = load_pem_private_key (key , password = None , backend = default_backend ())
208+ return load_ssh_public_key (key , backend = default_backend ())
205209 except ValueError :
206- key = load_pem_public_key (key , backend = default_backend ())
207- else :
208- raise TypeError ('Expecting a PEM-formatted key.' )
210+ raise InvalidAsymmetricKeyError
211+
212+ try :
213+ return load_pem_private_key (key , password = None , backend = default_backend ())
214+ except ValueError :
215+ pass
216+
217+ try :
218+ return load_pem_public_key (key , backend = default_backend ())
219+ except ValueError :
220+ pass
209221
210- return key
222+ raise InvalidAsymmetricKeyError
211223
212224 @staticmethod
213225 def to_jwk (key_obj ):
@@ -241,7 +253,7 @@ def to_jwk(key_obj):
241253 'e' : force_unicode (to_base64url_uint (numbers .e ))
242254 }
243255 else :
244- raise InvalidKeyError ('Not a public or private key ' )
256+ raise InvalidKeyError ('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance. ' )
245257
246258 return json .dumps (obj )
247259
@@ -250,22 +262,22 @@ def from_jwk(jwk):
250262 try :
251263 obj = json .loads (jwk )
252264 except ValueError :
253- raise InvalidKeyError ('Key is not valid JSON' )
265+ raise InvalidJwkError ('Key is not valid JSON' )
254266
255267 if obj .get ('kty' ) != 'RSA' :
256- raise InvalidKeyError ('Not an RSA key' )
268+ raise InvalidJwkError ('Not an RSA key' )
257269
258270 if 'd' in obj and 'e' in obj and 'n' in obj :
259271 # Private key
260272 if 'oth' in obj :
261- raise InvalidKeyError ('Unsupported RSA private key: > 2 primes not supported' )
273+ raise InvalidJwkError ('Unsupported RSA private key: > 2 primes not supported' )
262274
263275 other_props = ['p' , 'q' , 'dp' , 'dq' , 'qi' ]
264276 props_found = [prop in obj for prop in other_props ]
265277 any_props_found = any (props_found )
266278
267279 if any_props_found and not all (props_found ):
268- raise InvalidKeyError ('RSA key must include all parameters if any are present besides d' )
280+ raise InvalidJwkError ('RSA key must include all parameters if any are present besides d' )
269281
270282 public_numbers = RSAPublicNumbers (
271283 from_base64url_uint (obj ['e' ]), from_base64url_uint (obj ['n' ])
@@ -306,7 +318,7 @@ def from_jwk(jwk):
306318
307319 return numbers .public_key (default_backend ())
308320 else :
309- raise InvalidKeyError ('Not a public or private key' )
321+ raise InvalidKeyError ('Not a valid JWK public or private key' )
310322
311323 def sign (self , msg , key ):
312324 signer = key .signer (
@@ -349,24 +361,28 @@ def prepare_key(self, key):
349361 isinstance (key , EllipticCurvePublicKey ):
350362 return key
351363
352- if isinstance (key , string_types ):
353- key = force_bytes (key )
364+ if not isinstance (key , string_types ):
365+ raise InvalidAsymmetricKeyError
366+
367+ key = force_bytes (key )
354368
355- # Attempt to load key. We don't know if it's
356- # a Signing Key or a Verifying Key, so we try
357- # the Verifying Key first.
369+ if key .startswith (b'ecdsa-sha2-' ):
358370 try :
359- if key .startswith (b'ecdsa-sha2-' ):
360- key = load_ssh_public_key (key , backend = default_backend ())
361- else :
362- key = load_pem_public_key (key , backend = default_backend ())
371+ return load_ssh_public_key (key , backend = default_backend ())
363372 except ValueError :
364- key = load_pem_private_key ( key , password = None , backend = default_backend ())
373+ raise InvalidAsymmetricKeyError
365374
366- else :
367- raise TypeError ('Expecting a PEM-formatted key.' )
375+ try :
376+ return load_pem_public_key (key , backend = default_backend ())
377+ except ValueError :
378+ pass
379+
380+ try :
381+ return load_pem_private_key (key , password = None , backend = default_backend ())
382+ except ValueError :
383+ pass
368384
369- return key
385+ raise InvalidAsymmetricKeyError
370386
371387 def sign (self , msg , key ):
372388 signer = key .signer (ec .ECDSA (self .hash_alg ()))
0 commit comments