Skip to content

Commit 4e02432

Browse files
committed
Refactored api.py so that all JWT functions are now part of a PyJWT class.
- Created a singleton instance to preserve jwt.encode, jwt.decode, jwt.register_algorithms existing public APIs - Renamed load and verify_signature to _load and _verify_signature since they are not part of the existing public API - Modified related tests to use PyJWT._load and PyJWT._verify_signature
1 parent d471631 commit 4e02432

File tree

4 files changed

+359
-351
lines changed

4 files changed

+359
-351
lines changed

jwt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
__copyright__ = 'Copyright 2015 José Padilla'
1717

1818

19-
from .api import encode, decode, register_algorithm
19+
from .api import encode, decode, register_algorithm, PyJWT
2020
from .exceptions import (
2121
InvalidTokenError, DecodeError, ExpiredSignatureError,
2222
InvalidAudienceError, InvalidIssuerError,

jwt/algorithms.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import hashlib
22
import hmac
33

4-
from .api import register_algorithm
54
from .compat import constant_time_compare, string_types, text_type
65

76
try:
@@ -18,23 +17,23 @@
1817
has_crypto = False
1918

2019

21-
def _register_default_algorithms():
20+
def _register_default_algorithms(pyjwt_obj):
2221
"""
2322
Registers the algorithms that are implemented by the library.
2423
"""
25-
register_algorithm('none', NoneAlgorithm())
26-
register_algorithm('HS256', HMACAlgorithm(HMACAlgorithm.SHA256))
27-
register_algorithm('HS384', HMACAlgorithm(HMACAlgorithm.SHA384))
28-
register_algorithm('HS512', HMACAlgorithm(HMACAlgorithm.SHA512))
24+
pyjwt_obj.register_algorithm('none', NoneAlgorithm())
25+
pyjwt_obj.register_algorithm('HS256', HMACAlgorithm(HMACAlgorithm.SHA256))
26+
pyjwt_obj.register_algorithm('HS384', HMACAlgorithm(HMACAlgorithm.SHA384))
27+
pyjwt_obj.register_algorithm('HS512', HMACAlgorithm(HMACAlgorithm.SHA512))
2928

3029
if has_crypto:
31-
register_algorithm('RS256', RSAAlgorithm(RSAAlgorithm.SHA256))
32-
register_algorithm('RS384', RSAAlgorithm(RSAAlgorithm.SHA384))
33-
register_algorithm('RS512', RSAAlgorithm(RSAAlgorithm.SHA512))
30+
pyjwt_obj.register_algorithm('RS256', RSAAlgorithm(RSAAlgorithm.SHA256()))
31+
pyjwt_obj.register_algorithm('RS384', RSAAlgorithm(RSAAlgorithm.SHA384()))
32+
pyjwt_obj.register_algorithm('RS512', RSAAlgorithm(RSAAlgorithm.SHA512()))
3433

35-
register_algorithm('ES256', ECAlgorithm(ECAlgorithm.SHA256))
36-
register_algorithm('ES384', ECAlgorithm(ECAlgorithm.SHA384))
37-
register_algorithm('ES512', ECAlgorithm(ECAlgorithm.SHA512))
34+
pyjwt_obj.register_algorithm('ES256', ECAlgorithm(ECAlgorithm.SHA256()))
35+
pyjwt_obj.register_algorithm('ES384', ECAlgorithm(ECAlgorithm.SHA384()))
36+
pyjwt_obj.register_algorithm('ES512', ECAlgorithm(ECAlgorithm.SHA512()))
3837

3938

4039
class Algorithm(object):

jwt/api.py

Lines changed: 168 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -12,175 +12,179 @@
1212
)
1313
from .utils import base64url_decode, base64url_encode
1414

15-
16-
_algorithms = {}
17-
18-
19-
def register_algorithm(alg_id, alg_obj):
20-
"""
21-
Registers a new Algorithm for use when creating and verifying tokens.
22-
"""
23-
if alg_id in _algorithms:
24-
raise ValueError('Algorithm already has a handler.')
25-
26-
if not isinstance(alg_obj, Algorithm):
27-
raise TypeError('Object is not of type `Algorithm`')
28-
29-
_algorithms[alg_id] = alg_obj
30-
3115
from jwt.algorithms import Algorithm, _register_default_algorithms # NOQA
32-
_register_default_algorithms()
33-
34-
35-
def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
36-
segments = []
37-
38-
if algorithm is None:
39-
algorithm = 'none'
40-
41-
# Check that we get a mapping
42-
if not isinstance(payload, Mapping):
43-
raise TypeError('Expecting a mapping object, as json web token only'
44-
'support json objects.')
4516

46-
# Header
47-
header = {'typ': 'JWT', 'alg': algorithm}
48-
if headers:
49-
header.update(headers)
17+
class PyJWT(object):
18+
def __init__(self):
19+
self._algorithms = {}
20+
_register_default_algorithms(self)
21+
22+
def register_algorithm(self, alg_id, alg_obj):
23+
"""
24+
Registers a new Algorithm for use when creating and verifying tokens.
25+
"""
26+
if alg_id in self._algorithms:
27+
raise ValueError('Algorithm already has a handler.')
28+
29+
if not isinstance(alg_obj, Algorithm):
30+
raise TypeError('Object is not of type `Algorithm`')
31+
32+
self._algorithms[alg_id] = alg_obj
33+
34+
def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=None):
35+
segments = []
36+
37+
if algorithm is None:
38+
algorithm = 'none'
39+
40+
# Check that we get a mapping
41+
if not isinstance(payload, Mapping):
42+
raise TypeError('Expecting a mapping object, as json web token only'
43+
'support json objects.')
44+
45+
# Header
46+
header = {'typ': 'JWT', 'alg': algorithm}
47+
if headers:
48+
header.update(headers)
5049

51-
json_header = json.dumps(
52-
header,
53-
separators=(',', ':'),
54-
cls=json_encoder
55-
).encode('utf-8')
50+
json_header = json.dumps(
51+
header,
52+
separators=(',', ':'),
53+
cls=json_encoder
54+
).encode('utf-8')
5655

57-
segments.append(base64url_encode(json_header))
56+
segments.append(base64url_encode(json_header))
5857

59-
# Payload
60-
for time_claim in ['exp', 'iat', 'nbf']:
61-
# Convert datetime to a intDate value in known time-format claims
62-
if isinstance(payload.get(time_claim), datetime):
63-
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
58+
# Payload
59+
for time_claim in ['exp', 'iat', 'nbf']:
60+
# Convert datetime to a intDate value in known time-format claims
61+
if isinstance(payload.get(time_claim), datetime):
62+
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
6463

65-
json_payload = json.dumps(
66-
payload,
67-
separators=(',', ':'),
68-
cls=json_encoder
69-
).encode('utf-8')
64+
json_payload = json.dumps(
65+
payload,
66+
separators=(',', ':'),
67+
cls=json_encoder
68+
).encode('utf-8')
7069

71-
segments.append(base64url_encode(json_payload))
70+
segments.append(base64url_encode(json_payload))
7271

73-
# Segments
74-
signing_input = b'.'.join(segments)
75-
try:
76-
alg_obj = _algorithms[algorithm]
77-
key = alg_obj.prepare_key(key)
78-
signature = alg_obj.sign(signing_input, key)
79-
80-
except KeyError:
81-
raise NotImplementedError('Algorithm not supported')
82-
83-
segments.append(base64url_encode(signature))
84-
85-
return b'.'.join(segments)
86-
87-
88-
def decode(jwt, key='', verify=True, **kwargs):
89-
payload, signing_input, header, signature = load(jwt)
90-
91-
if verify:
92-
verify_signature(payload, signing_input, header, signature, key,
93-
**kwargs)
94-
95-
return payload
96-
97-
98-
def load(jwt):
99-
if isinstance(jwt, text_type):
100-
jwt = jwt.encode('utf-8')
101-
try:
102-
signing_input, crypto_segment = jwt.rsplit(b'.', 1)
103-
header_segment, payload_segment = signing_input.split(b'.', 1)
104-
except ValueError:
105-
raise DecodeError('Not enough segments')
106-
107-
try:
108-
header_data = base64url_decode(header_segment)
109-
except (TypeError, binascii.Error):
110-
raise DecodeError('Invalid header padding')
111-
try:
112-
header = json.loads(header_data.decode('utf-8'))
113-
except ValueError as e:
114-
raise DecodeError('Invalid header string: %s' % e)
115-
if not isinstance(header, Mapping):
116-
raise DecodeError('Invalid header string: must be a json object')
117-
118-
try:
119-
payload_data = base64url_decode(payload_segment)
120-
except (TypeError, binascii.Error):
121-
raise DecodeError('Invalid payload padding')
122-
try:
123-
payload = json.loads(payload_data.decode('utf-8'))
124-
except ValueError as e:
125-
raise DecodeError('Invalid payload string: %s' % e)
126-
if not isinstance(payload, Mapping):
127-
raise DecodeError('Invalid payload string: must be a json object')
128-
129-
try:
130-
signature = base64url_decode(crypto_segment)
131-
except (TypeError, binascii.Error):
132-
raise DecodeError('Invalid crypto padding')
133-
134-
return (payload, signing_input, header, signature)
135-
136-
137-
def verify_signature(payload, signing_input, header, signature, key='',
138-
verify_expiration=True, leeway=0, audience=None,
139-
issuer=None):
140-
141-
if isinstance(leeway, timedelta):
142-
leeway = timedelta_total_seconds(leeway)
143-
144-
if not isinstance(audience, (string_types, type(None))):
145-
raise TypeError('audience must be a string or None')
146-
147-
try:
148-
alg_obj = _algorithms[header['alg']]
149-
key = alg_obj.prepare_key(key)
150-
151-
if not alg_obj.verify(signing_input, key, signature):
152-
raise DecodeError('Signature verification failed')
153-
154-
except KeyError:
155-
raise DecodeError('Algorithm not supported')
156-
157-
if 'nbf' in payload and verify_expiration:
158-
utc_timestamp = timegm(datetime.utcnow().utctimetuple())
159-
160-
if payload['nbf'] > (utc_timestamp + leeway):
161-
raise ExpiredSignatureError('Signature not yet valid')
162-
163-
if 'exp' in payload and verify_expiration:
164-
utc_timestamp = timegm(datetime.utcnow().utctimetuple())
165-
166-
if payload['exp'] < (utc_timestamp - leeway):
167-
raise ExpiredSignatureError('Signature has expired')
168-
169-
if 'aud' in payload:
170-
audience_claims = payload['aud']
171-
if isinstance(audience_claims, string_types):
172-
audience_claims = [audience_claims]
173-
if not isinstance(audience_claims, list):
174-
raise InvalidAudienceError('Invalid claim format in token')
175-
if any(not isinstance(c, string_types) for c in audience_claims):
176-
raise InvalidAudienceError('Invalid claim format in token')
177-
if audience not in audience_claims:
178-
raise InvalidAudienceError('Invalid audience')
179-
elif audience is not None:
180-
# Application specified an audience, but it could not be
181-
# verified since the token does not contain a claim.
182-
raise InvalidAudienceError('No audience claim in token')
183-
184-
if issuer is not None:
185-
if payload.get('iss') != issuer:
186-
raise InvalidIssuerError('Invalid issuer')
72+
# Segments
73+
signing_input = b'.'.join(segments)
74+
try:
75+
alg_obj = self._algorithms[algorithm]
76+
key = alg_obj.prepare_key(key)
77+
signature = alg_obj.sign(signing_input, key)
78+
79+
except KeyError:
80+
raise NotImplementedError('Algorithm not supported')
81+
82+
segments.append(base64url_encode(signature))
83+
84+
return b'.'.join(segments)
85+
86+
87+
def decode(self, jwt, key='', verify=True, **kwargs):
88+
payload, signing_input, header, signature = self._load(jwt)
89+
90+
if verify:
91+
self._verify_signature(payload, signing_input, header, signature,
92+
key, **kwargs)
93+
94+
return payload
95+
96+
97+
def _load(self, jwt):
98+
if isinstance(jwt, text_type):
99+
jwt = jwt.encode('utf-8')
100+
try:
101+
signing_input, crypto_segment = jwt.rsplit(b'.', 1)
102+
header_segment, payload_segment = signing_input.split(b'.', 1)
103+
except ValueError:
104+
raise DecodeError('Not enough segments')
105+
106+
try:
107+
header_data = base64url_decode(header_segment)
108+
except (TypeError, binascii.Error):
109+
raise DecodeError('Invalid header padding')
110+
try:
111+
header = json.loads(header_data.decode('utf-8'))
112+
except ValueError as e:
113+
raise DecodeError('Invalid header string: %s' % e)
114+
if not isinstance(header, Mapping):
115+
raise DecodeError('Invalid header string: must be a json object')
116+
117+
try:
118+
payload_data = base64url_decode(payload_segment)
119+
except (TypeError, binascii.Error):
120+
raise DecodeError('Invalid payload padding')
121+
try:
122+
payload = json.loads(payload_data.decode('utf-8'))
123+
except ValueError as e:
124+
raise DecodeError('Invalid payload string: %s' % e)
125+
if not isinstance(payload, Mapping):
126+
raise DecodeError('Invalid payload string: must be a json object')
127+
128+
try:
129+
signature = base64url_decode(crypto_segment)
130+
except (TypeError, binascii.Error):
131+
raise DecodeError('Invalid crypto padding')
132+
133+
return (payload, signing_input, header, signature)
134+
135+
136+
def _verify_signature(self, payload, signing_input, header, signature,
137+
key='', verify_expiration=True, leeway=0,
138+
audience=None, issuer=None):
139+
140+
if isinstance(leeway, timedelta):
141+
leeway = timedelta_total_seconds(leeway)
142+
143+
if not isinstance(audience, (string_types, type(None))):
144+
raise TypeError('audience must be a string or None')
145+
146+
try:
147+
alg_obj = self._algorithms[header['alg']]
148+
key = alg_obj.prepare_key(key)
149+
150+
if not alg_obj.verify(signing_input, key, signature):
151+
raise DecodeError('Signature verification failed')
152+
153+
except KeyError:
154+
raise DecodeError('Algorithm not supported')
155+
156+
if 'nbf' in payload and verify_expiration:
157+
utc_timestamp = timegm(datetime.utcnow().utctimetuple())
158+
159+
if payload['nbf'] > (utc_timestamp + leeway):
160+
raise ExpiredSignatureError('Signature not yet valid')
161+
162+
if 'exp' in payload and verify_expiration:
163+
utc_timestamp = timegm(datetime.utcnow().utctimetuple())
164+
165+
if payload['exp'] < (utc_timestamp - leeway):
166+
raise ExpiredSignatureError('Signature has expired')
167+
168+
if 'aud' in payload:
169+
audience_claims = payload['aud']
170+
if isinstance(audience_claims, string_types):
171+
audience_claims = [audience_claims]
172+
if not isinstance(audience_claims, list):
173+
raise InvalidAudienceError('Invalid claim format in token')
174+
if any(not isinstance(c, string_types) for c in audience_claims):
175+
raise InvalidAudienceError('Invalid claim format in token')
176+
if audience not in audience_claims:
177+
raise InvalidAudienceError('Invalid audience')
178+
elif audience is not None:
179+
# Application specified an audience, but it could not be
180+
# verified since the token does not contain a claim.
181+
raise InvalidAudienceError('No audience claim in token')
182+
183+
if issuer is not None:
184+
if payload.get('iss') != issuer:
185+
raise InvalidIssuerError('Invalid issuer')
186+
187+
_jwt_global_obj = PyJWT()
188+
encode = _jwt_global_obj.encode
189+
decode = _jwt_global_obj.decode
190+
register_algorithm = _jwt_global_obj.register_algorithm

0 commit comments

Comments
 (0)