Skip to content

Commit 6290a1c

Browse files
committed
Merge pull request jpadilla#71 from mark-adams/algorithm-registry
Fixes jpadilla#70. Implemented registry pattern for class-based algorithms
2 parents 0afba10 + eff1505 commit 6290a1c

File tree

5 files changed

+265
-249
lines changed

5 files changed

+265
-249
lines changed

jwt/__init__.py

Lines changed: 31 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@
55
http://self-issued.info/docs/draft-jones-json-web-token-01.html
66
"""
77

8-
import base64
98
import binascii
10-
import hashlib
11-
import hmac
12-
from datetime import datetime, timedelta
9+
1310
from calendar import timegm
1411
from collections import Mapping
12+
from datetime import datetime, timedelta
1513

16-
from .compat import (json, string_types, text_type, constant_time_compare,
17-
timedelta_total_seconds)
14+
from jwt.utils import base64url_decode, base64url_encode
15+
16+
from .compat import (json, string_types, text_type, timedelta_total_seconds)
1817

1918

2019
__version__ = '0.4.1'
2120
__all__ = [
2221
# Functions
2322
'encode',
2423
'decode',
24+
'register_algorithm',
2525

2626
# Exceptions
2727
'InvalidTokenError',
@@ -33,9 +33,25 @@
3333
# Deprecated aliases
3434
'ExpiredSignature',
3535
'InvalidAudience',
36-
'InvalidIssuer',
36+
'InvalidIssuer'
3737
]
3838

39+
_algorithms = {}
40+
41+
42+
def register_algorithm(alg_id, alg_obj):
43+
""" Registers a new Algorithm for use when creating and verifying JWTs """
44+
if alg_id in _algorithms:
45+
raise ValueError('Algorithm already has a handler.')
46+
47+
if not isinstance(alg_obj, Algorithm):
48+
raise TypeError('Object is not of type `Algorithm`')
49+
50+
_algorithms[alg_id] = alg_obj
51+
52+
from jwt.algorithms import Algorithm, _register_default_algorithms # NOQA
53+
_register_default_algorithms()
54+
3955

4056
class InvalidTokenError(Exception):
4157
pass
@@ -56,187 +72,11 @@ class InvalidAudienceError(InvalidTokenError):
5672
class InvalidIssuerError(InvalidTokenError):
5773
pass
5874

59-
6075
# Compatibility aliases (deprecated)
6176
ExpiredSignature = ExpiredSignatureError
6277
InvalidAudience = InvalidAudienceError
6378
InvalidIssuer = InvalidIssuerError
6479

65-
signing_methods = {
66-
'none': lambda msg, key: b'',
67-
'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(),
68-
'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(),
69-
'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest()
70-
}
71-
72-
verify_methods = {
73-
'HS256': lambda msg, key: hmac.new(key, msg, hashlib.sha256).digest(),
74-
'HS384': lambda msg, key: hmac.new(key, msg, hashlib.sha384).digest(),
75-
'HS512': lambda msg, key: hmac.new(key, msg, hashlib.sha512).digest()
76-
}
77-
78-
79-
def prepare_HS_key(key):
80-
if not isinstance(key, string_types) and not isinstance(key, bytes):
81-
raise TypeError('Expecting a string- or bytes-formatted key.')
82-
83-
if isinstance(key, text_type):
84-
key = key.encode('utf-8')
85-
86-
return key
87-
88-
prepare_key_methods = {
89-
'none': lambda key: None,
90-
'HS256': prepare_HS_key,
91-
'HS384': prepare_HS_key,
92-
'HS512': prepare_HS_key
93-
}
94-
95-
try:
96-
from cryptography.hazmat.primitives import interfaces, hashes
97-
from cryptography.hazmat.primitives.serialization import (
98-
load_pem_private_key, load_pem_public_key, load_ssh_public_key
99-
)
100-
from cryptography.hazmat.primitives.asymmetric import ec, padding
101-
from cryptography.hazmat.backends import default_backend
102-
from cryptography.exceptions import InvalidSignature
103-
104-
def sign_rsa(msg, key, hashalg):
105-
signer = key.signer(
106-
padding.PKCS1v15(),
107-
hashalg
108-
)
109-
110-
signer.update(msg)
111-
return signer.finalize()
112-
113-
def verify_rsa(msg, key, hashalg, sig):
114-
verifier = key.verifier(
115-
sig,
116-
padding.PKCS1v15(),
117-
hashalg
118-
)
119-
120-
verifier.update(msg)
121-
122-
try:
123-
verifier.verify()
124-
return True
125-
except InvalidSignature:
126-
return False
127-
128-
signing_methods.update({
129-
'RS256': lambda msg, key: sign_rsa(msg, key, hashes.SHA256()),
130-
'RS384': lambda msg, key: sign_rsa(msg, key, hashes.SHA384()),
131-
'RS512': lambda msg, key: sign_rsa(msg, key, hashes.SHA512())
132-
})
133-
134-
verify_methods.update({
135-
'RS256': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA256(), sig),
136-
'RS384': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA384(), sig),
137-
'RS512': lambda msg, key, sig: verify_rsa(msg, key, hashes.SHA512(), sig)
138-
})
139-
140-
def prepare_RS_key(key):
141-
if isinstance(key, interfaces.RSAPrivateKey) or \
142-
isinstance(key, interfaces.RSAPublicKey):
143-
return key
144-
145-
if isinstance(key, string_types):
146-
if isinstance(key, text_type):
147-
key = key.encode('utf-8')
148-
149-
try:
150-
if key.startswith(b'ssh-rsa'):
151-
key = load_ssh_public_key(key, backend=default_backend())
152-
else:
153-
key = load_pem_private_key(key, password=None, backend=default_backend())
154-
except ValueError:
155-
key = load_pem_public_key(key, backend=default_backend())
156-
else:
157-
raise TypeError('Expecting a PEM-formatted key.')
158-
159-
return key
160-
161-
prepare_key_methods.update({
162-
'RS256': prepare_RS_key,
163-
'RS384': prepare_RS_key,
164-
'RS512': prepare_RS_key
165-
})
166-
167-
def sign_ecdsa(msg, key, hashalg):
168-
signer = key.signer(ec.ECDSA(hashalg))
169-
170-
signer.update(msg)
171-
return signer.finalize()
172-
173-
def verify_ecdsa(msg, key, hashalg, sig):
174-
verifier = key.verifier(sig, ec.ECDSA(hashalg))
175-
176-
verifier.update(msg)
177-
178-
try:
179-
verifier.verify()
180-
return True
181-
except InvalidSignature:
182-
return False
183-
184-
signing_methods.update({
185-
'ES256': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA256()),
186-
'ES384': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA384()),
187-
'ES512': lambda msg, key: sign_ecdsa(msg, key, hashes.SHA512()),
188-
})
189-
190-
verify_methods.update({
191-
'ES256': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA256(), sig),
192-
'ES384': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA384(), sig),
193-
'ES512': lambda msg, key, sig: verify_ecdsa(msg, key, hashes.SHA512(), sig),
194-
})
195-
196-
def prepare_ES_key(key):
197-
if isinstance(key, interfaces.EllipticCurvePrivateKey) or \
198-
isinstance(key, interfaces.EllipticCurvePublicKey):
199-
return key
200-
201-
if isinstance(key, string_types):
202-
if isinstance(key, text_type):
203-
key = key.encode('utf-8')
204-
205-
# Attempt to load key. We don't know if it's
206-
# a Signing Key or a Verifying Key, so we try
207-
# the Verifying Key first.
208-
try:
209-
key = load_pem_public_key(key, backend=default_backend())
210-
except ValueError:
211-
key = load_pem_private_key(key, password=None, backend=default_backend())
212-
213-
else:
214-
raise TypeError('Expecting a PEM-formatted key.')
215-
216-
return key
217-
218-
prepare_key_methods.update({
219-
'ES256': prepare_ES_key,
220-
'ES384': prepare_ES_key,
221-
'ES512': prepare_ES_key
222-
})
223-
224-
except ImportError:
225-
pass
226-
227-
228-
def base64url_decode(input):
229-
rem = len(input) % 4
230-
231-
if rem > 0:
232-
input += b'=' * (4 - rem)
233-
234-
return base64.urlsafe_b64decode(input)
235-
236-
237-
def base64url_encode(input):
238-
return base64.urlsafe_b64encode(input).replace(b'=', b'')
239-
24080

24181
def header(jwt):
24282
if isinstance(jwt, text_type):
@@ -290,8 +130,10 @@ def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
290130
# Segments
291131
signing_input = b'.'.join(segments)
292132
try:
293-
key = prepare_key_methods[algorithm](key)
294-
signature = signing_methods[algorithm](signing_input, key)
133+
alg_obj = _algorithms[algorithm]
134+
key = alg_obj.prepare_key(key)
135+
signature = alg_obj.sign(signing_input, key)
136+
295137
except KeyError:
296138
raise NotImplementedError('Algorithm not supported')
297139

@@ -360,17 +202,12 @@ def verify_signature(payload, signing_input, header, signature, key='',
360202
raise TypeError('audience must be a string or None')
361203

362204
try:
363-
algorithm = header['alg'].upper()
364-
key = prepare_key_methods[algorithm](key)
205+
alg_obj = _algorithms[header['alg'].upper()]
206+
key = alg_obj.prepare_key(key)
365207

366-
if algorithm.startswith('HS'):
367-
expected = verify_methods[algorithm](signing_input, key)
208+
if not alg_obj.verify(signing_input, key, signature):
209+
raise DecodeError('Signature verification failed')
368210

369-
if not constant_time_compare(signature, expected):
370-
raise DecodeError('Signature verification failed')
371-
else:
372-
if not verify_methods[algorithm](signing_input, key, signature):
373-
raise DecodeError('Signature verification failed')
374211
except KeyError:
375212
raise DecodeError('Algorithm not supported')
376213

0 commit comments

Comments
 (0)