Skip to content

Commit 3ada770

Browse files
author
Michael Davis
committed
Add flexible and complete verification options
Attempts to fix jpadilla#127
1 parent a2601ad commit 3ada770

File tree

2 files changed

+105
-13
lines changed

2 files changed

+105
-13
lines changed

jwt/api.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
class PyJWT(object):
19-
def __init__(self, algorithms=None):
19+
def __init__(self, algorithms=None, options=None):
2020
self._algorithms = get_default_algorithms()
2121
self._valid_algs = set(algorithms) if algorithms is not None else set(self._algorithms)
2222

@@ -25,6 +25,22 @@ def __init__(self, algorithms=None):
2525
if key not in self._valid_algs:
2626
del self._algorithms[key]
2727

28+
if not options:
29+
options = {}
30+
31+
self.default_options = {
32+
'verify_signature': True,
33+
'verify_exp': True,
34+
'verify_nbf': True,
35+
'verify_iat': True,
36+
'verify_aud': True,
37+
}
38+
39+
try:
40+
self.options = {k: options[k] if k in options else v for k, v in self.default_options.items()}
41+
except (ValueError, TypeError) as e:
42+
raise TypeError('options must be a dictionary: %s' % e)
43+
2844
def register_algorithm(self, alg_id, alg_obj):
2945
"""
3046
Registers a new Algorithm for use when creating and verifying tokens.
@@ -110,14 +126,16 @@ def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=Non
110126

111127
return b'.'.join(segments)
112128

113-
def decode(self, jwt, key='', verify=True, algorithms=None, **kwargs):
129+
def decode(self, jwt, key='', verify=True, algorithms=None, options=None, **kwargs):
114130
payload, signing_input, header, signature = self._load(jwt)
115131

116132
if verify:
117-
self._verify_signature(payload, signing_input, header, signature,
118-
key, algorithms)
133+
merged_options = self._merge_options(override_options=options)
134+
if merged_options.get('verify_signature'):
135+
self._verify_signature(payload, signing_input, header, signature,
136+
key, algorithms)
119137

120-
self._validate_claims(payload, **kwargs)
138+
self._validate_claims(payload, options=merged_options, **kwargs)
121139

122140
return payload
123141

@@ -177,8 +195,8 @@ def _verify_signature(self, payload, signing_input, header, signature,
177195
except KeyError:
178196
raise InvalidAlgorithmError('Algorithm not supported')
179197

180-
def _validate_claims(self, payload, verify_expiration=True, leeway=0,
181-
audience=None, issuer=None):
198+
def _validate_claims(self, payload, audience=None, issuer=None, leeway=0,
199+
options=None, **kwargs):
182200
if isinstance(leeway, timedelta):
183201
leeway = timedelta_total_seconds(leeway)
184202

@@ -187,7 +205,7 @@ def _validate_claims(self, payload, verify_expiration=True, leeway=0,
187205

188206
now = timegm(datetime.utcnow().utctimetuple())
189207

190-
if 'iat' in payload:
208+
if 'iat' in payload and options.get('verify_iat'):
191209
try:
192210
iat = int(payload['iat'])
193211
except ValueError:
@@ -196,7 +214,7 @@ def _validate_claims(self, payload, verify_expiration=True, leeway=0,
196214
if iat > (now + leeway):
197215
raise InvalidIssuedAtError('Issued At claim (iat) cannot be in the future.')
198216

199-
if 'nbf' in payload and verify_expiration:
217+
if 'nbf' in payload and options.get('verify_nbf'):
200218
try:
201219
nbf = int(payload['nbf'])
202220
except ValueError:
@@ -205,7 +223,7 @@ def _validate_claims(self, payload, verify_expiration=True, leeway=0,
205223
if nbf > (now + leeway):
206224
raise ImmatureSignatureError('The token is not yet valid (nbf)')
207225

208-
if 'exp' in payload and verify_expiration:
226+
if 'exp' in payload and options.get('verify_exp'):
209227
try:
210228
exp = int(payload['exp'])
211229
except ValueError:
@@ -214,7 +232,7 @@ def _validate_claims(self, payload, verify_expiration=True, leeway=0,
214232
if exp < (now - leeway):
215233
raise ExpiredSignatureError('Signature has expired')
216234

217-
if 'aud' in payload:
235+
if 'aud' in payload and options.get('verify_aud'):
218236
audience_claims = payload['aud']
219237
if isinstance(audience_claims, string_types):
220238
audience_claims = [audience_claims]
@@ -233,6 +251,15 @@ def _validate_claims(self, payload, verify_expiration=True, leeway=0,
233251
if payload.get('iss') != issuer:
234252
raise InvalidIssuerError('Invalid issuer')
235253

254+
def _merge_options(self, override_options=None):
255+
if not override_options:
256+
override_options = {}
257+
try:
258+
options = {k: override_options[k] if k in override_options else v for k, v in self.options.items()}
259+
except (ValueError, TypeError) as e:
260+
raise TypeError('options must be a dictionary: %s' % e)
261+
return options
262+
236263

237264
_jwt_global_obj = PyJWT()
238265
encode = _jwt_global_obj.encode

tests/test_api.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,26 @@ def test_algorithms_parameter_removes_alg_from_algorithms_list(self):
7171
self.assertNotIn('none', self.jwt.get_algorithms())
7272
self.assertIn('HS256', self.jwt.get_algorithms())
7373

74+
def test_default_options(self):
75+
self.assertEqual(self.jwt.default_options, self.jwt.options)
76+
77+
def test_override_options(self):
78+
self.jwt = PyJWT(options={'verify_exp': False, 'verify_nbf': False})
79+
expected_options = self.jwt.default_options
80+
expected_options['verify_exp'] = False
81+
expected_options['verify_nbf'] = False
82+
self.assertEqual(expected_options, self.jwt.options)
83+
84+
def test_non_existant_options_dont_exist(self):
85+
self.jwt = PyJWT(options={'verify_iat': False, 'foobar': False})
86+
expected_options = self.jwt.default_options
87+
expected_options['verify_iat'] = False
88+
self.assertEqual(expected_options, self.jwt.options)
89+
self.assertNotIn('foobar', self.jwt.options)
90+
91+
def test_options_must_be_dict(self):
92+
self.assertRaises(TypeError, PyJWT, options=object())
93+
7494
def test_encode_decode(self):
7595
secret = 'secret'
7696
jwt_message = self.jwt.encode(self.payload, secret)
@@ -467,14 +487,14 @@ def test_decode_skip_expiration_verification(self):
467487
secret = 'secret'
468488
jwt_message = self.jwt.encode(self.payload, secret)
469489

470-
self.jwt.decode(jwt_message, secret, verify_expiration=False)
490+
self.jwt.decode(jwt_message, secret, options={'verify_exp': False})
471491

472492
def test_decode_skip_notbefore_verification(self):
473493
self.payload['nbf'] = time.time() + 10
474494
secret = 'secret'
475495
jwt_message = self.jwt.encode(self.payload, secret)
476496

477-
self.jwt.decode(jwt_message, secret, verify_expiration=False)
497+
self.jwt.decode(jwt_message, secret, options={'verify_nbf': False})
478498

479499
def test_decode_with_expiration_with_leeway(self):
480500
self.payload['exp'] = utc_timestamp() - 2
@@ -765,6 +785,51 @@ def test_raise_exception_token_without_issuer(self):
765785
with self.assertRaises(InvalidIssuerError):
766786
self.jwt.decode(token, 'secret', issuer=issuer)
767787

788+
def test_skip_check_audience(self):
789+
payload = {
790+
'some': 'payload',
791+
'aud': 'urn:me',
792+
}
793+
token = self.jwt.encode(payload, 'secret')
794+
self.jwt.decode(token, 'secret', options={'verify_aud': False})
795+
796+
def test_skip_check_exp(self):
797+
payload = {
798+
'some': 'payload',
799+
'exp': datetime.utcnow() - timedelta(days=1)
800+
}
801+
token = self.jwt.encode(payload, 'secret')
802+
self.jwt.decode(token, 'secret', options={'verify_exp': False})
803+
804+
def test_skip_check_signature(self):
805+
token = ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
806+
".eyJzb21lIjoicGF5bG9hZCJ9"
807+
".4twFt5NiznN84AWoo1d7KO1T_yoc0Z6XOpOVswacPZA")
808+
self.jwt.decode(token, 'secret', options={'verify_signature': False})
809+
810+
def test_skip_check_iat(self):
811+
payload = {
812+
'some': 'payload',
813+
'iat': datetime.utcnow() + timedelta(days=1)
814+
}
815+
token = self.jwt.encode(payload, 'secret')
816+
self.jwt.decode(token, 'secret', options={'verify_iat': False})
817+
818+
def test_skip_check_nbf(self):
819+
payload = {
820+
'some': 'payload',
821+
'nbf': datetime.utcnow() + timedelta(days=1)
822+
}
823+
token = self.jwt.encode(payload, 'secret')
824+
self.jwt.decode(token, 'secret', options={'verify_nbf': False})
825+
826+
def test_decode_options_must_be_dict(self):
827+
payload = {
828+
'some': 'payload',
829+
}
830+
token = self.jwt.encode(payload, 'secret')
831+
self.assertRaises(TypeError, self.jwt.decode, token, 'secret', options=object())
832+
768833
def test_custom_json_encoder(self):
769834

770835
class CustomJSONEncoder(json.JSONEncoder):

0 commit comments

Comments
 (0)