Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import timedelta
from unittest.mock import ANY, Mock
from unittest.mock import ANY, Mock, patch

from faker import Faker
from flask import json
Expand All @@ -16,6 +16,9 @@
datetime_str,
)
from utils import worked_time
from time_tracker_api.time_entries.time_entries_model import (
TimeEntryCosmosDBModel,
)

from werkzeug.exceptions import NotFound, UnprocessableEntity, HTTPException

Expand Down Expand Up @@ -189,6 +192,81 @@ def test_get_time_entry_should_succeed_with_valid_id(
dao_get_all_mock.assert_called_once()


@patch(
'time_tracker_api.time_entries.time_entries_dao.TimeEntriesCosmosDBDao.create_event_context',
Mock(),
)
@patch(
'time_tracker_api.time_entries.time_entries_dao.TimeEntriesCosmosDBDao.build_custom_query',
Mock(),
)
@patch(
'time_tracker_api.time_entries.time_entries_dao.TimeEntriesCosmosDBDao.handle_date_filter_args',
Mock(),
)
@patch('msal.ConfidentialClientApplication', Mock())
@patch('utils.azure_users.AzureConnection.get_token', Mock())
@patch('utils.azure_users.AzureConnection.is_test_user')
@patch('utils.azure_users.AzureConnection.get_test_user_ids')
@pytest.mark.parametrize(
'current_user_is_tester, expected_user_ids',
[
(True, ['id1', 'id2']),
(False, ['id2']),
],
)
def test_get_time_entries_by_type_of_user(
get_test_user_ids_mock,
is_test_user_mock,
client: FlaskClient,
valid_header: dict,
time_entries_dao,
current_user_is_tester,
expected_user_ids,
):
test_user_id = "id1"
non_test_user_id = "id2"
te1 = TimeEntryCosmosDBModel(
{
"id": '1',
"project_id": "1",
"owner_id": test_user_id,
"tenant_id": '1',
"start_date": "",
}
)
te2 = TimeEntryCosmosDBModel(
{
"id": '2',
"project_id": "2",
"owner_id": non_test_user_id,
"tenant_id": '2',
"start_date": "",
}
)

find_all_mock = Mock()
find_all_mock.return_value = [te1, te2]

time_entries_dao.repository.find_all = find_all_mock

is_test_user_mock.return_value = current_user_is_tester
get_test_user_ids_mock.return_value = [test_user_id]

response = client.get(
"/time-entries?user_id=*", headers=valid_header, follow_redirects=True
)

is_test_user_mock.assert_called()
find_all_mock.assert_called()

expected_user_ids_in_time_entries = expected_user_ids
actual_user_ids_in_time_entries = [
time_entry["owner_id"] for time_entry in json.loads(response.data)
]
assert expected_user_ids_in_time_entries == actual_user_ids_in_time_entries


def test_get_time_entry_should_succeed_with_valid_id(
client: FlaskClient,
mocker: MockFixture,
Expand Down Expand Up @@ -595,6 +673,11 @@ def test_create_with_valid_uuid_format_should_return_created(
repository_container_create_item_mock.assert_called()


@patch('msal.ConfidentialClientApplication', Mock())
@patch('utils.azure_users.AzureConnection.get_token', Mock())
@patch(
'utils.azure_users.AzureConnection.is_test_user', Mock(return_value=True)
)
@pytest.mark.parametrize(
'url',
[
Expand Down Expand Up @@ -624,6 +707,11 @@ def test_get_all_passes_date_range_built_from_params_to_find_all(
assert 'end_date' in kwargs['date_range']


@patch('msal.ConfidentialClientApplication', Mock())
@patch('utils.azure_users.AzureConnection.get_token', Mock())
@patch(
'utils.azure_users.AzureConnection.is_test_user', Mock(return_value=True)
)
@pytest.mark.parametrize(
'url,start_date,end_date',
[
Expand Down Expand Up @@ -660,6 +748,11 @@ def test_get_all_passes_date_range_to_find_all_with_default_tz_offset(
assert kwargs['date_range']['end_date'] == end_date


@patch('msal.ConfidentialClientApplication', Mock())
@patch('utils.azure_users.AzureConnection.get_token', Mock())
@patch(
'utils.azure_users.AzureConnection.is_test_user', Mock(return_value=True)
)
@pytest.mark.parametrize(
'url,start_date,end_date',
[
Expand Down
44 changes: 44 additions & 0 deletions tests/utils/azure_users_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from unittest.mock import Mock, patch
from utils.azure_users import AzureConnection, ROLE_FIELD_VALUES, AzureUser_v2
from pytest import mark


@patch('msal.ConfidentialClientApplication')
@patch('utils.azure_users.AzureConnection.get_token')
@patch('utils.azure_users.AzureConnection._get_user')
@mark.parametrize(
'field_name,field_value,expected',
[
(ROLE_FIELD_VALUES['test'][0], ROLE_FIELD_VALUES['test'][1], True),
(ROLE_FIELD_VALUES['test'][0], None, False),
],
)
def test_azure_connection_is_test_user(
_get_user_mock,
get_token_mock,
msal_client_mock,
field_name,
field_value,
expected,
):
_get_user_mock.return_value = {field_name: field_value}
test_user_id = 'test-user-id'
az_conn = AzureConnection()
assert az_conn.is_test_user(test_user_id) == expected


@patch('msal.ConfidentialClientApplication')
@patch('utils.azure_users.AzureConnection.get_token')
@patch('utils.azure_users.AzureConnection._get_test_user_ids')
def test_azure_connection_get_test_user_ids(
_get_test_user_ids_mock,
get_token_mock,
msal_client_mock,
):
_get_test_user_ids_mock.return_value = [
{'objectId': 'ID1'},
{'objectId': 'ID2'},
]
ids = ['ID1', 'ID2']
az_conn = AzureConnection()
assert az_conn.get_test_user_ids() == ids
19 changes: 17 additions & 2 deletions time_tracker_api/time_entries/time_entries_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from time_tracker_api.database import CRUDDao, APICosmosDBDao
from time_tracker_api.security import current_user_id
from utils.azure_users import AzureConnection


class TimeEntriesDao(CRUDDao):
Expand Down Expand Up @@ -93,21 +94,35 @@ def build_custom_query(self, is_admin: bool, conditions: dict = None):
def get_all(self, conditions: dict = None, **kwargs) -> list:
event_ctx = self.create_event_context("read-many")
conditions.update({"owner_id": event_ctx.user_id})

is_complete_query = conditions.get("user_id") == '*'
custom_query = self.build_custom_query(
is_admin=event_ctx.is_admin,
conditions=conditions,
)
date_range = self.handle_date_filter_args(args=conditions)
limit = conditions.get("limit", None)
conditions.pop("limit", None)
return self.repository.find_all(
azure_connection = AzureConnection()
current_user_is_tester = azure_connection.is_test_user(
event_ctx.user_id
)
time_entries_list = self.repository.find_all(
event_ctx,
conditions=conditions,
custom_sql_conditions=custom_query,
date_range=date_range,
max_count=limit,
)
if not current_user_is_tester and is_complete_query:
test_user_ids = azure_connection.get_test_user_ids()
time_entries_list = [
time_entry
for time_entry in time_entries_list
if time_entry.owner_id not in test_user_ids
]
return time_entries_list
else:
return time_entries_list

def get_lastest_entries_by_project(
self, conditions: dict = None, **kwargs
Expand Down
32 changes: 31 additions & 1 deletion utils/azure_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def get_token(self):

def users(self) -> List[AzureUser]:
endpoint = "{endpoint}/users?api-version=1.6&$select=displayName,otherMails,objectId,{role_field}".format(
endpoint=self.config.ENDPOINT, role_field=self.role_field,
endpoint=self.config.ENDPOINT,
role_field=self.role_field,
)
response = requests.get(endpoint, auth=BearerAuth(self.access_token))

Expand Down Expand Up @@ -185,3 +186,32 @@ def get_role_data(self, role_id, is_grant=True):
return {field_name: field_value}
else:
return {field_name: None}

def _get_user(self, user_id):
endpoint = "{endpoint}/users/{user_id}?api-version=1.6".format(
endpoint=self.config.ENDPOINT, user_id=user_id
)
response = requests.get(endpoint, auth=BearerAuth(self.access_token))
assert 200 == response.status_code
return response.json()

def is_test_user(self, user_id):
response = self._get_user(user_id)
field_name, field_value = ROLE_FIELD_VALUES['test']
return field_name in response and field_value == response[field_name]

def _get_test_user_ids(self):
field_name, field_value = ROLE_FIELD_VALUES['test']
endpoint = "{endpoint}/users?api-version=1.6&$select=objectId,{field_name}&$filter={field_name} eq '{field_value}'".format(
endpoint=self.config.ENDPOINT,
field_name=field_name,
field_value=field_value,
)
response = requests.get(endpoint, auth=BearerAuth(self.access_token))
assert 200 == response.status_code
assert 'value' in response.json()
return response.json()['value']

def get_test_user_ids(self):
response = self._get_test_user_ids()
return [item['objectId'] for item in response]