Skip to content

Commit f013347

Browse files
authored
Merge pull request jpadilla#202 from jpadilla/add-jwk-for-hmac-rsa
Add JWK support for HMAC and RSA keys
2 parents 3edaa53 + 42b0114 commit f013347

File tree

11 files changed

+554
-130
lines changed

11 files changed

+554
-130
lines changed

jwt/algorithms.py

Lines changed: 144 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
import hashlib
22
import hmac
3+
import json
34

4-
from .compat import binary_type, constant_time_compare, is_string_type
5+
6+
from .compat import constant_time_compare, string_types
57
from .exceptions import InvalidKeyError
6-
from .utils import der_to_raw_signature, raw_to_der_signature
8+
from .utils import (
9+
base64url_decode, base64url_encode, der_to_raw_signature,
10+
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
11+
to_base64url_uint
12+
)
713

814
try:
915
from cryptography.hazmat.primitives import hashes
1016
from cryptography.hazmat.primitives.serialization import (
1117
load_pem_private_key, load_pem_public_key, load_ssh_public_key
1218
)
1319
from cryptography.hazmat.primitives.asymmetric.rsa import (
14-
RSAPrivateKey, RSAPublicKey
20+
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
21+
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
1522
)
1623
from cryptography.hazmat.primitives.asymmetric.ec import (
1724
EllipticCurvePrivateKey, EllipticCurvePublicKey
@@ -77,6 +84,20 @@ def verify(self, msg, key, sig):
7784
"""
7885
raise NotImplementedError
7986

87+
@staticmethod
88+
def to_jwk(key_obj):
89+
"""
90+
Serializes a given RSA key into a JWK
91+
"""
92+
raise NotImplementedError
93+
94+
@staticmethod
95+
def from_jwk(jwk):
96+
"""
97+
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
98+
"""
99+
raise NotImplementedError
100+
80101

81102
class NoneAlgorithm(Algorithm):
82103
"""
@@ -112,11 +133,7 @@ def __init__(self, hash_alg):
112133
self.hash_alg = hash_alg
113134

114135
def prepare_key(self, key):
115-
if not is_string_type(key):
116-
raise TypeError('Expecting a string- or bytes-formatted key.')
117-
118-
if not isinstance(key, binary_type):
119-
key = key.encode('utf-8')
136+
key = force_bytes(key)
120137

121138
invalid_strings = [
122139
b'-----BEGIN PUBLIC KEY-----',
@@ -131,6 +148,22 @@ def prepare_key(self, key):
131148

132149
return key
133150

151+
@staticmethod
152+
def to_jwk(key_obj):
153+
return json.dumps({
154+
'k': force_unicode(base64url_encode(force_bytes(key_obj))),
155+
'kty': 'oct'
156+
})
157+
158+
@staticmethod
159+
def from_jwk(jwk):
160+
obj = json.loads(jwk)
161+
162+
if obj.get('kty') != 'oct':
163+
raise InvalidKeyError('Not an HMAC key')
164+
165+
return base64url_decode(obj['k'])
166+
134167
def sign(self, msg, key):
135168
return hmac.new(key, msg, self.hash_alg).digest()
136169

@@ -156,9 +189,8 @@ def prepare_key(self, key):
156189
isinstance(key, RSAPublicKey):
157190
return key
158191

159-
if is_string_type(key):
160-
if not isinstance(key, binary_type):
161-
key = key.encode('utf-8')
192+
if isinstance(key, string_types):
193+
key = force_bytes(key)
162194

163195
try:
164196
if key.startswith(b'ssh-rsa'):
@@ -172,6 +204,105 @@ def prepare_key(self, key):
172204

173205
return key
174206

207+
@staticmethod
208+
def to_jwk(key_obj):
209+
obj = None
210+
211+
if getattr(key_obj, 'private_numbers', None):
212+
# Private key
213+
numbers = key_obj.private_numbers()
214+
215+
obj = {
216+
'kty': 'RSA',
217+
'key_ops': ['sign'],
218+
'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
219+
'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
220+
'd': force_unicode(to_base64url_uint(numbers.d)),
221+
'p': force_unicode(to_base64url_uint(numbers.p)),
222+
'q': force_unicode(to_base64url_uint(numbers.q)),
223+
'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
224+
'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
225+
'qi': force_unicode(to_base64url_uint(numbers.iqmp))
226+
}
227+
228+
elif getattr(key_obj, 'verifier', None):
229+
# Public key
230+
numbers = key_obj.public_numbers()
231+
232+
obj = {
233+
'kty': 'RSA',
234+
'key_ops': ['verify'],
235+
'n': force_unicode(to_base64url_uint(numbers.n)),
236+
'e': force_unicode(to_base64url_uint(numbers.e))
237+
}
238+
else:
239+
raise InvalidKeyError('Not a public or private key')
240+
241+
return json.dumps(obj)
242+
243+
@staticmethod
244+
def from_jwk(jwk):
245+
try:
246+
obj = json.loads(jwk)
247+
except ValueError:
248+
raise InvalidKeyError('Key is not valid JSON')
249+
250+
if obj.get('kty') != 'RSA':
251+
raise InvalidKeyError('Not an RSA key')
252+
253+
if 'd' in obj and 'e' in obj and 'n' in obj:
254+
# Private key
255+
if 'oth' in obj:
256+
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
257+
258+
other_props = ['p', 'q', 'dp', 'dq', 'qi']
259+
props_found = [prop in obj for prop in other_props]
260+
any_props_found = any(props_found)
261+
262+
if any_props_found and not all(props_found):
263+
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
264+
265+
public_numbers = RSAPublicNumbers(
266+
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
267+
)
268+
269+
if any_props_found:
270+
numbers = RSAPrivateNumbers(
271+
d=from_base64url_uint(obj['d']),
272+
p=from_base64url_uint(obj['p']),
273+
q=from_base64url_uint(obj['q']),
274+
dmp1=from_base64url_uint(obj['dp']),
275+
dmq1=from_base64url_uint(obj['dq']),
276+
iqmp=from_base64url_uint(obj['qi']),
277+
public_numbers=public_numbers
278+
)
279+
else:
280+
d = from_base64url_uint(obj['d'])
281+
p, q = rsa_recover_prime_factors(
282+
public_numbers.n, d, public_numbers.e
283+
)
284+
285+
numbers = RSAPrivateNumbers(
286+
d=d,
287+
p=p,
288+
q=q,
289+
dmp1=rsa_crt_dmp1(d, p),
290+
dmq1=rsa_crt_dmq1(d, q),
291+
iqmp=rsa_crt_iqmp(p, q),
292+
public_numbers=public_numbers
293+
)
294+
295+
return numbers.private_key(default_backend())
296+
elif 'n' in obj and 'e' in obj:
297+
# Public key
298+
numbers = RSAPublicNumbers(
299+
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
300+
)
301+
302+
return numbers.public_key(default_backend())
303+
else:
304+
raise InvalidKeyError('Not a public or private key')
305+
175306
def sign(self, msg, key):
176307
signer = key.signer(
177308
padding.PKCS1v15(),
@@ -213,9 +344,8 @@ def prepare_key(self, key):
213344
isinstance(key, EllipticCurvePublicKey):
214345
return key
215346

216-
if is_string_type(key):
217-
if not isinstance(key, binary_type):
218-
key = key.encode('utf-8')
347+
if isinstance(key, string_types):
348+
key = force_bytes(key)
219349

220350
# Attempt to load key. We don't know if it's
221351
# a Signing Key or a Verifying Key, so we try

jwt/api_jws.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .algorithms import Algorithm, get_default_algorithms # NOQA
88
from .compat import binary_type, string_types, text_type
99
from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError
10-
from .utils import base64url_decode, base64url_encode, merge_dict
10+
from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict
1111

1212

1313
class PyJWS(object):
@@ -82,11 +82,13 @@ def encode(self, payload, key, algorithm='HS256', headers=None,
8282
self._validate_headers(headers)
8383
header.update(headers)
8484

85-
json_header = json.dumps(
86-
header,
87-
separators=(',', ':'),
88-
cls=json_encoder
89-
).encode('utf-8')
85+
json_header = force_bytes(
86+
json.dumps(
87+
header,
88+
separators=(',', ':'),
89+
cls=json_encoder
90+
)
91+
)
9092

9193
segments.append(base64url_encode(json_header))
9294
segments.append(base64url_encode(payload))

jwt/compat.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
versions of python, and compatibility wrappers around optional packages.
44
"""
55
# flake8: noqa
6-
import sys
76
import hmac
7+
import struct
8+
import sys
89

910

1011
PY3 = sys.version_info[0] == 3
@@ -20,10 +21,6 @@
2021
string_types = (text_type, binary_type)
2122

2223

23-
def is_string_type(val):
24-
return any([isinstance(val, typ) for typ in string_types])
25-
26-
2724
def timedelta_total_seconds(delta):
2825
try:
2926
delta.total_seconds
@@ -56,3 +53,24 @@ def constant_time_compare(val1, val2):
5653
result |= ord(x) ^ ord(y)
5754

5855
return result == 0
56+
57+
# Use int.to_bytes if it exists (Python 3)
58+
if getattr(int, 'to_bytes', None):
59+
def bytes_from_int(val):
60+
remaining = val
61+
byte_length = 0
62+
63+
while remaining != 0:
64+
remaining = remaining >> 8
65+
byte_length += 1
66+
67+
return val.to_bytes(byte_length, 'big', signed=False)
68+
else:
69+
def bytes_from_int(val):
70+
buf = []
71+
while val:
72+
val, remainder = divmod(val, 256)
73+
buf.append(remainder)
74+
75+
buf.reverse()
76+
return struct.pack('%sB' % len(buf), *buf)

jwt/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import base64
22
import binascii
3+
import struct
4+
5+
from .compat import binary_type, bytes_from_int, text_type
36

47
try:
58
from cryptography.hazmat.primitives.asymmetric.utils import (
@@ -9,7 +12,28 @@
912
pass
1013

1114

15+
def force_unicode(value):
16+
if isinstance(value, binary_type):
17+
return value.decode('utf-8')
18+
elif isinstance(value, text_type):
19+
return value
20+
else:
21+
raise TypeError('Expected a string value')
22+
23+
24+
def force_bytes(value):
25+
if isinstance(value, text_type):
26+
return value.encode('utf-8')
27+
elif isinstance(value, binary_type):
28+
return value
29+
else:
30+
raise TypeError('Expected a string value')
31+
32+
1233
def base64url_decode(input):
34+
if isinstance(input, text_type):
35+
input = input.encode('ascii')
36+
1337
rem = len(input) % 4
1438

1539
if rem > 0:
@@ -22,6 +46,28 @@ def base64url_encode(input):
2246
return base64.urlsafe_b64encode(input).replace(b'=', b'')
2347

2448

49+
def to_base64url_uint(val):
50+
if val < 0:
51+
raise ValueError('Must be a positive integer')
52+
53+
int_bytes = bytes_from_int(val)
54+
55+
if len(int_bytes) == 0:
56+
int_bytes = b'\x00'
57+
58+
return base64url_encode(int_bytes)
59+
60+
61+
def from_base64url_uint(val):
62+
if isinstance(val, text_type):
63+
val = val.encode('ascii')
64+
65+
data = base64url_decode(val)
66+
67+
buf = struct.unpack('%sB' % len(data), data)
68+
return int(''.join(["%02x" % byte for byte in buf]), 16)
69+
70+
2571
def merge_dict(original, updates):
2672
if not updates:
2773
return original

0 commit comments

Comments
 (0)