|
1 | 1 | import msal
|
2 | 2 | import os
|
3 | 3 | import requests
|
| 4 | +from typing import List |
4 | 5 |
|
5 | 6 |
|
6 |
| -class MSALConfig: |
7 |
| - MSAL_CLIENT_ID = os.environ.get('MSAL_CLIENT_ID') |
8 |
| - MSAL_AUTHORITY = os.environ.get('MSAL_AUTHORITY') |
9 |
| - MSAL_SECRET = os.environ.get('MSAL_SECRET') |
10 |
| - MSAL_SCOPE = os.environ.get('MSAL_SCOPE') |
11 |
| - MSAL_ENDPOINT = os.environ.get('MSAL_ENDPOINT') |
12 |
| - """ |
13 |
| - TODO : Add validation to ensure variables are set |
14 |
| - """ |
| 7 | +class MSConfig: |
| 8 | + def check_variables_are_defined(): |
| 9 | + auth_variables = [ |
| 10 | + 'MS_CLIENT_ID', |
| 11 | + 'MS_AUTHORITY', |
| 12 | + 'MS_SECRET', |
| 13 | + 'MS_SCOPE', |
| 14 | + 'MS_ENDPOINT', |
| 15 | + ] |
| 16 | + for var in auth_variables: |
| 17 | + if var not in os.environ: |
| 18 | + raise EnvironmentError( |
| 19 | + "{} is not defined in the environment".format(var) |
| 20 | + ) |
| 21 | + |
| 22 | + check_variables_are_defined() |
| 23 | + CLIENT_ID = os.environ.get('MS_CLIENT_ID') |
| 24 | + AUTHORITY = os.environ.get('MS_AUTHORITY') |
| 25 | + SECRET = os.environ.get('MS_SECRET') |
| 26 | + SCOPE = os.environ.get('MS_SCOPE') |
| 27 | + ENDPOINT = os.environ.get('MS_ENDPOINT') |
| 28 | + |
| 29 | + |
| 30 | +class BearerAuth(requests.auth.AuthBase): |
| 31 | + def __init__(self, access_token): |
| 32 | + self.access_token = access_token |
| 33 | + |
| 34 | + def __call__(self, r): |
| 35 | + r.headers["Authorization"] = f'Bearer {self.access_token}' |
| 36 | + return r |
15 | 37 |
|
16 | 38 |
|
17 | 39 | class AzureUser:
|
18 |
| - def __init__(self, id, display_name, email): |
| 40 | + def __init__(self, id, name, email): |
19 | 41 | self.id = id
|
20 |
| - self.display_name = display_name |
| 42 | + self.name = name |
21 | 43 | self.email = email
|
22 | 44 |
|
23 | 45 |
|
24 |
| -class AzureUsers: |
25 |
| - def __init__(self, config=MSALConfig): |
| 46 | +class AzureConnection: |
| 47 | + def __init__(self, config=MSConfig): |
26 | 48 | self.client = msal.ConfidentialClientApplication(
|
27 |
| - config.MSAL_CLIENT_ID, |
28 |
| - authority=config.MSAL_AUTHORITY, |
29 |
| - client_credential=config.MSAL_SECRET, |
| 49 | + config.CLIENT_ID, |
| 50 | + authority=config.AUTHORITY, |
| 51 | + client_credential=config.SECRET, |
30 | 52 | )
|
31 | 53 | self.config = config
|
32 |
| - self.set_token() |
| 54 | + self.access_token = self.get_token() |
33 | 55 |
|
34 |
| - def set_token(self): |
| 56 | + def get_token(self): |
35 | 57 | response = self.client.acquire_token_for_client(
|
36 |
| - scopes=self.config.MSAL_SCOPE |
| 58 | + scopes=self.config.SCOPE |
37 | 59 | )
|
38 | 60 | if "access_token" in response:
|
39 |
| - # Call a protected API with the access token. |
40 |
| - # print(response["access_token"]) |
41 |
| - self.access_token = response['access_token'] |
| 61 | + return response['access_token'] |
42 | 62 | else:
|
43 |
| - print(response.get("error")) |
44 |
| - print(response.get("error_description")) |
45 |
| - print( |
46 |
| - response.get("correlation_id") |
47 |
| - ) # You might need this when reporting a bug |
48 |
| - |
49 |
| - def get_user_info_by_id(self, id): |
50 |
| - endpoint = f"{self.config.MSAL_ENDPOINT}/users/{id}?api-version=1.6&$select=displayName,otherMails" |
51 |
| - # print(endpoint) |
52 |
| - http_headers = { |
53 |
| - 'Authorization': f'Bearer {self.access_token}', |
54 |
| - 'Accept': 'application/json', |
55 |
| - 'Content-Type': 'application/json', |
56 |
| - } |
57 |
| - data = requests.get( |
58 |
| - endpoint, headers=http_headers, stream=False |
59 |
| - ).json() |
60 |
| - return data |
| 63 | + error_info = f"{response['error']} {response['error_description']}" |
| 64 | + raise ValueError(error_info) |
61 | 65 |
|
62 |
| - def get_users_info(self): |
63 |
| - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
64 |
| - http_headers = { |
65 |
| - 'Authorization': f'Bearer {self.access_token}', |
66 |
| - 'Accept': 'application/json', |
67 |
| - 'Content-Type': 'application/json', |
68 |
| - } |
69 |
| - data = requests.get( |
70 |
| - endpoint, headers=http_headers, stream=False |
71 |
| - ).json() |
72 |
| - return data |
| 66 | + def users(self) -> List[AzureUser]: |
| 67 | + def to_azure_user(item) -> AzureUser: |
| 68 | + there_is_email = len(item['otherMails']) > 0 |
| 69 | + id = item['objectId'] |
| 70 | + name = item['displayName'] |
| 71 | + email = item['otherMails'][0] if there_is_email else '' |
| 72 | + return AzureUser(id, name, email) |
73 | 73 |
|
74 |
| - def users(self): |
75 |
| - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
76 |
| - http_headers = { |
77 |
| - 'Authorization': f'Bearer {self.access_token}', |
78 |
| - #'Accept': 'application/json', |
79 |
| - #'Content-Type': 'application/json', |
80 |
| - } |
81 |
| - data = requests.get( |
82 |
| - endpoint, headers=http_headers, stream=False |
83 |
| - ).json() |
84 |
| - # print(data) |
| 74 | + endpoint = f"{self.config.ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" |
| 75 | + response = requests.get(endpoint, auth=BearerAuth(self.access_token)) |
85 | 76 |
|
86 |
| - users = [] |
87 |
| - for value in data['value']: |
88 |
| - user = AzureUser( |
89 |
| - id=value['objectId'], |
90 |
| - display_name=value['displayName'], |
91 |
| - email=value['otherMails'][0], |
92 |
| - ) |
93 |
| - users.append(user) |
94 |
| - return users |
| 77 | + assert 200 == response.status_code |
| 78 | + assert 'value' in response.json() |
| 79 | + return [to_azure_user(item) for item in response.json()['value']] |
0 commit comments