From 74e9eeaa07caa4037b427a1cb114d548493c578e Mon Sep 17 00:00:00 2001 From: roberto Date: Wed, 3 Jun 2020 15:42:55 -0500 Subject: [PATCH] feat: make a function to get user's info --- .env.template | 10 +- .../time_entries/time_entries_model.py | 5 +- utils/azure_users.py | 127 ++++++++---------- 3 files changed, 64 insertions(+), 78 deletions(-) diff --git a/.env.template b/.env.template index 9d230a6e..d5bfb21c 100644 --- a/.env.template +++ b/.env.template @@ -15,8 +15,8 @@ export DATABASE_MASTER_KEY= export DATABASE_NAME= ## For Azure Users interaction -export MSAL_AUTHORITY= -export MSAL_CLIENT_ID= -export MSAL_SCOPE= -export MSAL_SECRET= -export MSAL_ENDPOINT= +export MS_AUTHORITY= +export MS_CLIENT_ID= +export MS_SCOPE= +export MS_SECRET= +export MS_ENDPOINT= diff --git a/time_tracker_api/time_entries/time_entries_model.py b/time_tracker_api/time_entries/time_entries_model.py index 82ebdbc8..1d4c0eac 100644 --- a/time_tracker_api/time_entries/time_entries_model.py +++ b/time_tracker_api/time_entries/time_entries_model.py @@ -31,7 +31,7 @@ from utils import worked_time from utils.worked_time import str_to_datetime -from utils.azure_users import AzureUsers +from utils.azure_users import AzureConnection from time_tracker_api.projects.projects_model import ProjectCosmosDBModel from time_tracker_api.projects import projects_model from time_tracker_api.database import CRUDDao, APICosmosDBDao @@ -179,7 +179,8 @@ def find_all( activities = activity_dao.get_all() add_activity_name_to_time_entries(time_entries, activities) - add_user_email_to_time_entries(time_entries, AzureUsers().users()) + users = AzureConnection().users() + add_user_email_to_time_entries(time_entries, users) return time_entries def on_create(self, new_item_data: dict, event_context: EventContext): diff --git a/utils/azure_users.py b/utils/azure_users.py index d57d56d0..b31f0f0e 100644 --- a/utils/azure_users.py +++ b/utils/azure_users.py @@ -1,94 +1,79 @@ import msal import os import requests +from typing import List -class MSALConfig: - MSAL_CLIENT_ID = os.environ.get('MSAL_CLIENT_ID') - MSAL_AUTHORITY = os.environ.get('MSAL_AUTHORITY') - MSAL_SECRET = os.environ.get('MSAL_SECRET') - MSAL_SCOPE = os.environ.get('MSAL_SCOPE') - MSAL_ENDPOINT = os.environ.get('MSAL_ENDPOINT') - """ - TODO : Add validation to ensure variables are set - """ +class MSConfig: + def check_variables_are_defined(): + auth_variables = [ + 'MS_CLIENT_ID', + 'MS_AUTHORITY', + 'MS_SECRET', + 'MS_SCOPE', + 'MS_ENDPOINT', + ] + for var in auth_variables: + if var not in os.environ: + raise EnvironmentError( + "{} is not defined in the environment".format(var) + ) + + check_variables_are_defined() + CLIENT_ID = os.environ.get('MS_CLIENT_ID') + AUTHORITY = os.environ.get('MS_AUTHORITY') + SECRET = os.environ.get('MS_SECRET') + SCOPE = os.environ.get('MS_SCOPE') + ENDPOINT = os.environ.get('MS_ENDPOINT') + + +class BearerAuth(requests.auth.AuthBase): + def __init__(self, access_token): + self.access_token = access_token + + def __call__(self, r): + r.headers["Authorization"] = f'Bearer {self.access_token}' + return r class AzureUser: - def __init__(self, id, display_name, email): + def __init__(self, id, name, email): self.id = id - self.display_name = display_name + self.name = name self.email = email -class AzureUsers: - def __init__(self, config=MSALConfig): +class AzureConnection: + def __init__(self, config=MSConfig): self.client = msal.ConfidentialClientApplication( - config.MSAL_CLIENT_ID, - authority=config.MSAL_AUTHORITY, - client_credential=config.MSAL_SECRET, + config.CLIENT_ID, + authority=config.AUTHORITY, + client_credential=config.SECRET, ) self.config = config - self.set_token() + self.access_token = self.get_token() - def set_token(self): + def get_token(self): response = self.client.acquire_token_for_client( - scopes=self.config.MSAL_SCOPE + scopes=self.config.SCOPE ) if "access_token" in response: - # Call a protected API with the access token. - # print(response["access_token"]) - self.access_token = response['access_token'] + return response['access_token'] else: - print(response.get("error")) - print(response.get("error_description")) - print( - response.get("correlation_id") - ) # You might need this when reporting a bug - - def get_user_info_by_id(self, id): - endpoint = f"{self.config.MSAL_ENDPOINT}/users/{id}?api-version=1.6&$select=displayName,otherMails" - # print(endpoint) - http_headers = { - 'Authorization': f'Bearer {self.access_token}', - 'Accept': 'application/json', - 'Content-Type': 'application/json', - } - data = requests.get( - endpoint, headers=http_headers, stream=False - ).json() - return data + error_info = f"{response['error']} {response['error_description']}" + raise ValueError(error_info) - def get_users_info(self): - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" - http_headers = { - 'Authorization': f'Bearer {self.access_token}', - 'Accept': 'application/json', - 'Content-Type': 'application/json', - } - data = requests.get( - endpoint, headers=http_headers, stream=False - ).json() - return data + def users(self) -> List[AzureUser]: + def to_azure_user(item) -> AzureUser: + there_is_email = len(item['otherMails']) > 0 + id = item['objectId'] + name = item['displayName'] + email = item['otherMails'][0] if there_is_email else '' + return AzureUser(id, name, email) - def users(self): - endpoint = f"{self.config.MSAL_ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" - http_headers = { - 'Authorization': f'Bearer {self.access_token}', - #'Accept': 'application/json', - #'Content-Type': 'application/json', - } - data = requests.get( - endpoint, headers=http_headers, stream=False - ).json() - # print(data) + endpoint = f"{self.config.ENDPOINT}/users?api-version=1.6&$select=displayName,otherMails,objectId" + response = requests.get(endpoint, auth=BearerAuth(self.access_token)) - users = [] - for value in data['value']: - user = AzureUser( - id=value['objectId'], - display_name=value['displayName'], - email=value['otherMails'][0], - ) - users.append(user) - return users + assert 200 == response.status_code + assert 'value' in response.json() + return [to_azure_user(item) for item in response.json()['value']]