|
1 | 1 | import json |
2 | 2 | import warnings |
3 | | - |
4 | 3 | from calendar import timegm |
5 | | -from collections import Mapping |
| 4 | +from collections import Iterable, Mapping |
6 | 5 | from datetime import datetime, timedelta |
7 | 6 |
|
8 | 7 | from .api_jws import PyJWS |
@@ -103,8 +102,8 @@ def _validate_claims(self, payload, options, audience=None, issuer=None, |
103 | 102 | if isinstance(leeway, timedelta): |
104 | 103 | leeway = leeway.total_seconds() |
105 | 104 |
|
106 | | - if not isinstance(audience, (string_types, type(None))): |
107 | | - raise TypeError('audience must be a string or None') |
| 105 | + if not isinstance(audience, (string_types, type(None), Iterable)): |
| 106 | + raise TypeError('audience must be a string, iterable, or None') |
108 | 107 |
|
109 | 108 | self._validate_required_claims(payload, options) |
110 | 109 |
|
@@ -177,7 +176,11 @@ def _validate_aud(self, payload, audience): |
177 | 176 | raise InvalidAudienceError('Invalid claim format in token') |
178 | 177 | if any(not isinstance(c, string_types) for c in audience_claims): |
179 | 178 | raise InvalidAudienceError('Invalid claim format in token') |
180 | | - if audience not in audience_claims: |
| 179 | + |
| 180 | + if isinstance(audience, string_types): |
| 181 | + audience = [audience] |
| 182 | + |
| 183 | + if not any(aud in audience_claims for aud in audience): |
181 | 184 | raise InvalidAudienceError('Invalid audience') |
182 | 185 |
|
183 | 186 | def _validate_iss(self, payload, issuer): |
|
0 commit comments