|
1 | 1 | import msal |
2 | 2 | import os |
3 | 3 | import requests |
| 4 | +from typing import List |
4 | 5 |
|
5 | 6 |
|
6 | | -class MSALConfig: |
7 | | - MSAL_CLIENT_ID = os.environ.get('MSAL_CLIENT_ID') |
8 | | - MSAL_AUTHORITY = os.environ.get('MSAL_AUTHORITY') |
9 | | - MSAL_SECRET = os.environ.get('MSAL_SECRET') |
10 | | - MSAL_SCOPE = os.environ.get('MSAL_SCOPE') |
11 | | - MSAL_ENDPOINT = os.environ.get('MSAL_ENDPOINT') |
12 | | - """ |
13 | | - TODO : Add validation to ensure variables are set |
14 | | - """ |
| 7 | +class MSConfig: |
| 8 | + def check_variables_are_defined(): |
| 9 | + auth_variables = [ |
| 10 | + 'MS_CLIENT_ID', |
| 11 | + 'MS_AUTHORITY', |
| 12 | + 'MS_SECRET', |
| 13 | + 'MS_SCOPE', |
| 14 | + 'MS_ENDPOINT', |
| 15 | + ] |
| 16 | + for var in auth_variables: |
| 17 | + if var not in os.environ: |
| 18 | + raise EnvironmentError( |
| 19 | + "{} is not defined in the environment".format(var) |
| 20 | + ) |
| 21 | + |
| 22 | + check_variables_are_defined() |
| 23 | + CLIENT_ID = os.environ.get('MS_CLIENT_ID') |
| 24 | + AUTHORITY = os.environ.get('MS_AUTHORITY') |
| 25 | + SECRET = os.environ.get('MS_SECRET') |
| 26 | + SCOPE = os.environ.get('MS_SCOPE') |
| 27 | + ENDPOINT = os.environ.get('MS_ENDPOINT') |
| 28 | + |
| 29 | + |
| 30 | +class BearerAuth(requests.auth.AuthBase): |
| 31 | + def __init__(self, access_token): |
| 32 | + self.access_token = access_token |
| 33 | + |
| 34 | + def __call__(self, r): |
| 35 | + r.headers["Authorization"] = f'Bearer {self.access_token}' |
| 36 | + return r |
15 | 37 |
|
16 | 38 |
|
17 | 39 | class AzureUser: |
18 | | - def __init__(self, id, display_name, email): |
| 40 | + def __init__(self, id, name, email): |
19 | 41 | self.id = id |
20 | | - self.display_name = display_name |
| 42 | + self.name = name |
21 | 43 | self.email = email |
22 | 44 |
|
23 | 45 |
|
24 | | -class AzureUsers: |
25 | | - def __init__(self, config=MSALConfig): |
| 46 | +class AzureConnection: |
| 47 | + def __init__(self, config=MSConfig): |
26 | 48 | self.client = msal.ConfidentialClientApplication( |
27 | | - config.MSAL_CLIENT_ID, |
28 | | - authority=config.MSAL_AUTHORITY, |
29 | | - client_credential=config.MSAL_SECRET, |
| 49 | + config.CLIENT_ID, |
| 50 | + authority=config.AUTHORITY, |
| 51 | + client_credential=config.SECRET, |
30 | 52 | ) |
31 | 53 | self.config = config |
32 | 54 | self.set_token() |
33 | 55 |
|
34 | 56 | def set_token(self): |
35 | 57 | response = self.client.acquire_token_for_client( |
36 | | - scopes=self.config.MSAL_SCOPE |
| 58 | + scopes=self.config.SCOPE |
37 | 59 | ) |
38 | 60 | if "access_token" in response: |
39 | | - # Call a protected API with the access token. |
40 | | - # print(response["access_token"]) |
41 | 61 | self.access_token = response['access_token'] |
42 | 62 | else: |
43 | | - print(response.get("error")) |
44 | | - print(response.get("error_description")) |
45 | | - print( |
46 | | - response.get("correlation_id") |
47 | | - ) # You might need this when reporting a bug |
48 | | - |
49 | | - def get_user_info_by_id(self, id): |
50 | | - endpoint = f"{self.config.MSAL_ENDPOINT}/users/{id}?api-version=1.6&$select=displayName,otherMails" |
51 | | - # print(endpoint) |
52 | | - http_headers = { |
53 | | - 'Authorization': f'Bearer {self.access_token}', |
54 | | - 'Accept': 'application/json', |
55 | | - 'Content-Type': 'application/json', |
56 | | - } |
57 | | - data = requests.get( |
58 | | - endpoint, headers=http_headers, stream=False |
59 | | - ).json() |
60 | | - return data |
| 63 | + error_info = f"{response['error']} {response['error_description']}" |
| 64 | + raise ValueError(error_info) |
61 | 65 |
|
62 | | - def get_users_info(self): |
63 | | - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
64 | | - http_headers = { |
65 | | - 'Authorization': f'Bearer {self.access_token}', |
66 | | - 'Accept': 'application/json', |
67 | | - 'Content-Type': 'application/json', |
68 | | - } |
69 | | - data = requests.get( |
70 | | - endpoint, headers=http_headers, stream=False |
71 | | - ).json() |
72 | | - return data |
| 66 | + def users(self) -> List[AzureUser]: |
| 67 | + def to_azure_user(item) -> AzureUser: |
| 68 | + there_is_email = len(item['otherMails']) > 0 |
| 69 | + id = item['objectId'] |
| 70 | + name = item['displayName'] |
| 71 | + email = item['otherMails'][0] if there_is_email else '' |
| 72 | + return AzureUser(id, name, email) |
73 | 73 |
|
74 | | - def users(self): |
75 | | - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
76 | | - http_headers = { |
77 | | - 'Authorization': f'Bearer {self.access_token}', |
78 | | - #'Accept': 'application/json', |
79 | | - #'Content-Type': 'application/json', |
80 | | - } |
81 | | - data = requests.get( |
82 | | - endpoint, headers=http_headers, stream=False |
83 | | - ).json() |
84 | | - # print(data) |
| 74 | + endpoint = f"{self.config.ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
| 75 | + response = requests.get(endpoint, auth=BearerAuth(self.access_token)) |
85 | 76 |
|
86 | | - users = [] |
87 | | - for value in data['value']: |
88 | | - user = AzureUser( |
89 | | - id=value['objectId'], |
90 | | - display_name=value['displayName'], |
91 | | - email=value['otherMails'][0], |
92 | | - ) |
93 | | - users.append(user) |
94 | | - return users |
| 77 | + assert 200 == response.status_code |
| 78 | + assert 'value' in response.json() |
| 79 | + return [to_azure_user(item) for item in response.json()['value']] |
0 commit comments