Skip to content

Commit 8de4428

Browse files
authored
Prefer headers['alg'] to algorithm parameter in encode(). (jpadilla#673)
* Prefer headers['alg'] to algorithm parameter in encode(). * Fix lack of @crypto_required. * Prefer headers['alg'] to algorithm parameter in encode(). * Prefer headers['alg'] to algorithm parameter in encode(). * Make algorithm parameter of encode() Optioanl explicitly.
1 parent cfe5261 commit 8de4428

File tree

5 files changed

+36
-4
lines changed

5 files changed

+36
-4
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Changed
1313
Fixed
1414
~~~~~
1515

16+
- Prefer `headers["alg"]` to `algorithm` in `jwt.encode()`. `#673 <https://github.com/jpadilla/pyjwt/pull/673>`__
1617
- Fix aud validation to support {'aud': null} case. `#670 <https://github.com/jpadilla/pyjwt/pull/670>`__
1718

1819
Added

docs/api.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ API Reference
1313
* for **asymmetric algorithms**: PEM-formatted private key, a multiline string
1414
* for **symmetric algorithms**: plain string, sufficiently long for security
1515

16-
:param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``
17-
:param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``
16+
:param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``.
17+
If ``headers`` includes ``alg``, it will be preferred to this parameter.
18+
:param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``.
1819
:param json.JSONEncoder json_encoder: custom JSON encoder for ``payload`` and ``headers``
1920
:rtype: str
2021
:returns: a JSON Web Token

jwt/api_jws.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def encode(
7777
self,
7878
payload: bytes,
7979
key: str,
80-
algorithm: str = "HS256",
80+
algorithm: Optional[str] = "HS256",
8181
headers: Optional[Dict] = None,
8282
json_encoder: Optional[Type[json.JSONEncoder]] = None,
8383
) -> str:
@@ -86,6 +86,10 @@ def encode(
8686
if algorithm is None:
8787
algorithm = "none"
8888

89+
# Prefer headers["alg"] if present to algorithm parameter.
90+
if headers and "alg" in headers and headers["alg"]:
91+
algorithm = headers["alg"]
92+
8993
if algorithm not in self._valid_algs:
9094
pass
9195

jwt/api_jwt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def encode(
3838
self,
3939
payload: Dict[str, Any],
4040
key: str,
41-
algorithm: str = "HS256",
41+
algorithm: Optional[str] = "HS256",
4242
headers: Optional[Dict] = None,
4343
json_encoder: Optional[Type[json.JSONEncoder]] = None,
4444
) -> str:

tests/test_api_jws.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,32 @@ def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
166166
exception = context.value
167167
assert str(exception) == "Algorithm not supported"
168168

169+
def test_encode_with_headers_alg_none(self, jws, payload):
170+
msg = jws.encode(payload, key=None, headers={"alg": "none"})
171+
with pytest.raises(DecodeError) as context:
172+
jws.decode(msg, algorithms=["none"])
173+
assert str(context.value) == "Signature verification failed"
174+
175+
@crypto_required
176+
def test_encode_with_headers_alg_es256(self, jws, payload):
177+
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
178+
priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
179+
with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
180+
pub_key = load_pem_public_key(ec_pub_file.read())
181+
182+
msg = jws.encode(payload, priv_key, headers={"alg": "ES256"})
183+
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])
184+
185+
@crypto_required
186+
def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
187+
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
188+
priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
189+
with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
190+
pub_key = load_pem_public_key(ec_pub_file.read())
191+
192+
msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
193+
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])
194+
169195
def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
170196
example_jws = (
171197
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256

0 commit comments

Comments
 (0)