Skip to content

Commit 94d102b

Browse files
Split PyJWT/PyJWS classes to tighten type interfaces (jpadilla#559)
The class PyJWT was previously a subclass of PyJWS. However, this combination does not follow the Liskov substitution principle. That is, using PyJWT in place of a PyJWS would not produce correct results or follow type contracts. While these classes look to share a common interface it doesn't go beyond the method names "encode" and "decode" and so is merely superficial. The classes have been split into two. PyJWT now uses composition instead of inheritance to achieve the desired behavior. Splitting the classes in this way allowed for precising the type interfaces. The complete parameter to .decode() has been removed. This argument was used to alter the return type of .decode(). Now, there are two different methods with more explicit return types and values. The new method name is .decode_complete(). This fills the previous role filled by .decode(..., complete=True). Closes jpadilla#554, jpadilla#396, jpadilla#394 Co-authored-by: Sam Bull <git@sambull.org> Co-authored-by: Sam Bull <git@sambull.org>
1 parent 2e1e69d commit 94d102b

File tree

7 files changed

+94
-46
lines changed

7 files changed

+94
-46
lines changed

jwt/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from .api_jws import PyJWS
2-
from .api_jwt import (
3-
PyJWT,
4-
decode,
5-
encode,
1+
from .api_jws import (
2+
PyJWS,
63
get_unverified_header,
74
register_algorithm,
85
unregister_algorithm,
96
)
7+
from .api_jwt import PyJWT, decode, encode
108
from .exceptions import (
119
DecodeError,
1210
ExpiredSignatureError,

jwt/api_jws.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import binascii
22
import json
33
from collections.abc import Mapping
4-
from typing import Dict, List, Optional, Type, Union
4+
from typing import Any, Dict, List, Optional, Type
55

66
from .algorithms import (
77
Algorithm,
@@ -77,7 +77,7 @@ def get_algorithms(self):
7777

7878
def encode(
7979
self,
80-
payload: Union[Dict, bytes],
80+
payload: bytes,
8181
key: str,
8282
algorithm: str = "HS256",
8383
headers: Optional[Dict] = None,
@@ -127,15 +127,14 @@ def encode(
127127

128128
return encoded_string.decode("utf-8")
129129

130-
def decode(
130+
def decode_complete(
131131
self,
132132
jwt: str,
133133
key: str = "",
134134
algorithms: List[str] = None,
135135
options: Dict = None,
136-
complete: bool = False,
137136
**kwargs,
138-
):
137+
) -> Dict[str, Any]:
139138
if options is None:
140139
options = {}
141140
merged_options = {**self.options, **options}
@@ -153,14 +152,22 @@ def decode(
153152
payload, signing_input, header, signature, key, algorithms
154153
)
155154

156-
if complete:
157-
return {
158-
"payload": payload,
159-
"header": header,
160-
"signature": signature,
161-
}
155+
return {
156+
"payload": payload,
157+
"header": header,
158+
"signature": signature,
159+
}
162160

163-
return payload
161+
def decode(
162+
self,
163+
jwt: str,
164+
key: str = "",
165+
algorithms: List[str] = None,
166+
options: Dict = None,
167+
**kwargs,
168+
) -> str:
169+
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
170+
return decoded["payload"]
164171

165172
def get_unverified_header(self, jwt):
166173
"""Returns back the JWT header parameters as a dict()
@@ -249,6 +256,7 @@ def _validate_kid(self, kid):
249256

250257
_jws_global_obj = PyJWS()
251258
encode = _jws_global_obj.encode
259+
decode_complete = _jws_global_obj.decode_complete
252260
decode = _jws_global_obj.decode
253261
register_algorithm = _jws_global_obj.register_algorithm
254262
unregister_algorithm = _jws_global_obj.unregister_algorithm

jwt/api_jwt.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime, timedelta
55
from typing import Any, Dict, List, Optional, Type, Union
66

7-
from .api_jws import PyJWS
7+
from . import api_jws
88
from .exceptions import (
99
DecodeError,
1010
ExpiredSignatureError,
@@ -16,8 +16,11 @@
1616
)
1717

1818

19-
class PyJWT(PyJWS):
20-
header_type = "JWT"
19+
class PyJWT:
20+
def __init__(self, options=None):
21+
if options is None:
22+
options = {}
23+
self.options = {**self._get_default_options(), **options}
2124

2225
@staticmethod
2326
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
@@ -33,7 +36,7 @@ def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
3336

3437
def encode(
3538
self,
36-
payload: Union[Dict, bytes],
39+
payload: Dict[str, Any],
3740
key: str,
3841
algorithm: str = "HS256",
3942
headers: Optional[Dict] = None,
@@ -59,20 +62,18 @@ def encode(
5962
payload, separators=(",", ":"), cls=json_encoder
6063
).encode("utf-8")
6164

62-
return super().encode(
65+
return api_jws.encode(
6366
json_payload, key, algorithm, headers, json_encoder
6467
)
6568

66-
def decode(
69+
def decode_complete(
6770
self,
6871
jwt: str,
6972
key: str = "",
7073
algorithms: List[str] = None,
7174
options: Dict = None,
72-
complete: bool = False,
7375
**kwargs,
7476
) -> Dict[str, Any]:
75-
7677
if options is None:
7778
options = {"verify_signature": True}
7879
else:
@@ -83,20 +84,16 @@ def decode(
8384
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
8485
)
8586

86-
decoded = super().decode(
87+
decoded = api_jws.decode_complete(
8788
jwt,
8889
key=key,
8990
algorithms=algorithms,
9091
options=options,
91-
complete=complete,
9292
**kwargs,
9393
)
9494

9595
try:
96-
if complete:
97-
payload = json.loads(decoded["payload"])
98-
else:
99-
payload = json.loads(decoded)
96+
payload = json.loads(decoded["payload"])
10097
except ValueError as e:
10198
raise DecodeError("Invalid payload string: %s" % e)
10299
if not isinstance(payload, dict):
@@ -106,11 +103,19 @@ def decode(
106103
merged_options = {**self.options, **options}
107104
self._validate_claims(payload, merged_options, **kwargs)
108105

109-
if complete:
110-
decoded["payload"] = payload
111-
return decoded
106+
decoded["payload"] = payload
107+
return decoded
112108

113-
return payload
109+
def decode(
110+
self,
111+
jwt: str,
112+
key: str = "",
113+
algorithms: List[str] = None,
114+
options: Dict = None,
115+
**kwargs,
116+
) -> Dict[str, Any]:
117+
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
118+
return decoded["payload"]
114119

115120
def _validate_claims(
116121
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
@@ -215,7 +220,5 @@ def _validate_iss(self, payload, issuer):
215220

216221
_jwt_global_obj = PyJWT()
217222
encode = _jwt_global_obj.encode
223+
decode_complete = _jwt_global_obj.decode_complete
218224
decode = _jwt_global_obj.decode
219-
register_algorithm = _jwt_global_obj.register_algorithm
220-
unregister_algorithm = _jwt_global_obj.unregister_algorithm
221-
get_unverified_header = _jwt_global_obj.get_unverified_header

jwt/jwks_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import urllib.request
33

44
from .api_jwk import PyJWKSet
5-
from .api_jwt import decode as decode_token
5+
from .api_jwt import decode_complete as decode_token
66
from .exceptions import PyJWKClientError
77

88

@@ -50,8 +50,6 @@ def get_signing_key(self, kid):
5050
return signing_key
5151

5252
def get_signing_key_from_jwt(self, token):
53-
unverified = decode_token(
54-
token, complete=True, options={"verify_signature": False}
55-
)
53+
unverified = decode_token(token, options={"verify_signature": False})
5654
header = unverified["header"]
5755
return self.get_signing_key(header.get("kid"))

jwt/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def base64url_decode(input):
3232
return base64.urlsafe_b64decode(input)
3333

3434

35-
def base64url_encode(input):
35+
def base64url_encode(input: bytes) -> bytes:
3636
return base64.urlsafe_b64encode(input).replace(b"=", b"")
3737

3838

tests/test_api_jws.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,27 @@ def test_decodes_valid_jws(self, jws, payload):
215215

216216
assert decoded_payload == payload
217217

218+
def test_decodes_complete_valid_jws(self, jws, payload):
219+
example_secret = "secret"
220+
example_jws = (
221+
b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9."
222+
b"aGVsbG8gd29ybGQ."
223+
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
224+
)
225+
226+
decoded = jws.decode_complete(
227+
example_jws, example_secret, algorithms=["HS256"]
228+
)
229+
230+
assert decoded == {
231+
"header": {"alg": "HS256", "typ": "JWT"},
232+
"payload": payload,
233+
"signature": (
234+
b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4"
235+
b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3"
236+
),
237+
}
238+
218239
# 'Control' Elliptic Curve jws created by another library.
219240
# Used to test for regressions that could affect both
220241
# encoding / decoding operations equally (causing tests

tests/test_api_jwt.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,27 @@ def test_decodes_valid_jwt(self, jwt):
4747

4848
assert decoded_payload == example_payload
4949

50+
def test_decodes_complete_valid_jwt(self, jwt):
51+
example_payload = {"hello": "world"}
52+
example_secret = "secret"
53+
example_jwt = (
54+
b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
55+
b".eyJoZWxsbyI6ICJ3b3JsZCJ9"
56+
b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8"
57+
)
58+
decoded = jwt.decode_complete(
59+
example_jwt, example_secret, algorithms=["HS256"]
60+
)
61+
62+
assert decoded == {
63+
"header": {"alg": "HS256", "typ": "JWT"},
64+
"payload": example_payload,
65+
"signature": (
66+
b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2'
67+
b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?"
68+
),
69+
}
70+
5071
def test_load_verify_valid_jwt(self, jwt):
5172
example_payload = {"hello": "world"}
5273
example_secret = "secret"
@@ -313,13 +334,12 @@ def test_decode_with_expiration_with_leeway(self, jwt, payload):
313334
secret = "secret"
314335
jwt_message = jwt.encode(payload, secret)
315336

316-
decoded_payload, signing, header, signature = jwt._load(jwt_message)
317-
318337
# With 3 seconds leeway, should be ok
319338
for leeway in (3, timedelta(seconds=3)):
320-
jwt.decode(
339+
decoded = jwt.decode(
321340
jwt_message, secret, leeway=leeway, algorithms=["HS256"]
322341
)
342+
assert decoded == payload
323343

324344
# With 1 seconds, should fail
325345
for leeway in (1, timedelta(seconds=1)):

0 commit comments

Comments
 (0)