Skip to content

Commit b39b9a7

Browse files
committed
Merge pull request jpadilla#135 from mark-adams/minor-updates
Minor refactorings to make things a little cleaner
2 parents 5fd54be + 90577f7 commit b39b9a7

File tree

4 files changed

+50
-35
lines changed

4 files changed

+50
-35
lines changed

jwt/api.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError,
1414
InvalidIssuerError
1515
)
16-
from .utils import base64url_decode, base64url_encode
16+
from .utils import base64url_decode, base64url_encode, merge_dict
1717

1818

1919
class PyJWT(object):
@@ -29,15 +29,15 @@ def __init__(self, algorithms=None, options=None):
2929
if not options:
3030
options = {}
3131

32-
self.default_options = {
32+
default_options = {
3333
'verify_signature': True,
3434
'verify_exp': True,
3535
'verify_nbf': True,
3636
'verify_iat': True,
3737
'verify_aud': True,
3838
}
3939

40-
self.options = self._merge_options(self.default_options, options)
40+
self.options = merge_dict(default_options, options)
4141

4242
def register_algorithm(self, alg_id, alg_obj):
4343
"""
@@ -85,6 +85,7 @@ def encode(self, payload, key, algorithm='HS256', headers=None, json_encoder=Non
8585

8686
# Header
8787
header = {'typ': 'JWT', 'alg': algorithm}
88+
8889
if headers:
8990
header.update(headers)
9091

@@ -128,7 +129,7 @@ def decode(self, jwt, key='', verify=True, algorithms=None, options=None, **kwar
128129
payload, signing_input, header, signature = self._load(jwt)
129130

130131
if verify:
131-
merged_options = self._merge_options(override_options=options)
132+
merged_options = merge_dict(self.options, options)
132133
if merged_options.get('verify_signature'):
133134
self._verify_signature(payload, signing_input, header, signature,
134135
key, algorithms)
@@ -251,21 +252,6 @@ def _validate_claims(self, payload, audience=None, issuer=None, leeway=0,
251252
if payload.get('iss') != issuer:
252253
raise InvalidIssuerError('Invalid issuer')
253254

254-
def _merge_options(self, default_options=None, override_options=None):
255-
if not default_options:
256-
default_options = {}
257-
258-
if not override_options:
259-
override_options = {}
260-
261-
try:
262-
merged_options = self.default_options.copy()
263-
merged_options.update(override_options)
264-
except (AttributeError, ValueError) as e:
265-
raise TypeError('options must be a dictionary: %s' % e)
266-
267-
return merged_options
268-
269255

270256
_jwt_global_obj = PyJWT()
271257
encode = _jwt_global_obj.encode

jwt/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,16 @@ def base64url_decode(input):
1212

1313
def base64url_encode(input):
1414
return base64.urlsafe_b64encode(input).replace(b'=', b'')
15+
16+
17+
def merge_dict(original, updates):
18+
if not updates:
19+
return original
20+
21+
try:
22+
merged_options = original.copy()
23+
merged_options.update(updates)
24+
except (AttributeError, ValueError) as e:
25+
raise TypeError('original and updates must be a dictionary: %s' % e)
26+
27+
return merged_options

tests/test_algorithms.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,31 +113,31 @@ def test_rsa_should_reject_non_string_key(self):
113113
def test_rsa_verify_should_return_false_if_signature_invalid(self):
114114
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
115115

116-
jwt_message = ensure_bytes('Hello World!')
116+
message = ensure_bytes('Hello World!')
117117

118-
jwt_sig = base64.b64decode(ensure_bytes(
118+
sig = base64.b64decode(ensure_bytes(
119119
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
120120
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
121121
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
122122
'sn5jGz1H07jYYbi9diixN8IUhXeTafwFg02IcONhum29V40Wu6O5tAKWlJX'
123123
'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA'
124124
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
125125

126-
jwt_sig += ensure_bytes('123') # Signature is now invalid
126+
sig += ensure_bytes('123') # Signature is now invalid
127127

128128
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
129-
jwt_pub_key = algo.prepare_key(keyfile.read())
129+
pub_key = algo.prepare_key(keyfile.read())
130130

131-
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
131+
result = algo.verify(message, pub_key, sig)
132132
self.assertFalse(result)
133133

134134
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')
135135
def test_rsa_verify_should_return_true_if_signature_valid(self):
136136
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
137137

138-
jwt_message = ensure_bytes('Hello World!')
138+
message = ensure_bytes('Hello World!')
139139

140-
jwt_sig = base64.b64decode(ensure_bytes(
140+
sig = base64.b64decode(ensure_bytes(
141141
'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp'
142142
'10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl'
143143
'2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix'
@@ -146,9 +146,9 @@ def test_rsa_verify_should_return_true_if_signature_valid(self):
146146
'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA=='))
147147

148148
with open(key_path('testkey_rsa.pub'), 'r') as keyfile:
149-
jwt_pub_key = algo.prepare_key(keyfile.read())
149+
pub_key = algo.prepare_key(keyfile.read())
150150

151-
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
151+
result = algo.verify(message, pub_key, sig)
152152
self.assertTrue(result)
153153

154154
@unittest.skipIf(not has_crypto, 'Not supported without cryptography library')

tests/test_api.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError,
1515
InvalidIssuerError
1616
)
17+
from jwt.utils import base64url_decode
1718

18-
from .compat import text_type, unittest
19+
from .compat import string_types, text_type, unittest
1920
from .utils import ensure_bytes
2021

2122
try:
@@ -80,19 +81,16 @@ def test_algorithms_parameter_removes_alg_from_algorithms_list(self):
8081
self.assertNotIn('none', self.jwt.get_algorithms())
8182
self.assertIn('HS256', self.jwt.get_algorithms())
8283

83-
def test_default_options(self):
84-
self.assertEqual(self.jwt.default_options, self.jwt.options)
85-
8684
def test_override_options(self):
8785
self.jwt = PyJWT(options={'verify_exp': False, 'verify_nbf': False})
88-
expected_options = self.jwt.default_options
86+
expected_options = self.jwt.options
8987
expected_options['verify_exp'] = False
9088
expected_options['verify_nbf'] = False
9189
self.assertEqual(expected_options, self.jwt.options)
9290

93-
def test_non_default_options_persist(self):
91+
def test_non_object_options_persist(self):
9492
self.jwt = PyJWT(options={'verify_iat': False, 'foobar': False})
95-
expected_options = self.jwt.default_options
93+
expected_options = self.jwt.options
9694
expected_options['verify_iat'] = False
9795
expected_options['foobar'] = False
9896
self.assertEqual(expected_options, self.jwt.options)
@@ -880,6 +878,24 @@ def default(self, o):
880878
payload = self.jwt.decode(token, 'secret')
881879
self.assertEqual(payload, {'some_decimal': 'it worked'})
882880

881+
def test_encode_headers_parameter_adds_headers(self):
882+
headers = {'testheader': True}
883+
token = self.jwt.encode({'msg': 'hello world'}, 'secret', headers=headers)
884+
885+
if not isinstance(token, string_types):
886+
token = token.decode()
887+
888+
header = token[0:token.index('.')].encode()
889+
header = base64url_decode(header)
890+
891+
if not isinstance(header, text_type):
892+
header = header.decode()
893+
894+
header_obj = json.loads(header)
895+
896+
self.assertIn('testheader', header_obj)
897+
self.assertEqual(header_obj['testheader'], headers['testheader'])
898+
883899

884900
if __name__ == '__main__':
885901
unittest.main()

0 commit comments

Comments
 (0)