Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions tests/utils/azure_users_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
],
)
def test_azure_connection_is_test_user(
get_mock,
field_name,
field_value,
is_test_user_expected_value,
get_mock,
field_name,
field_value,
is_test_user_expected_value,
):
response_mock = Mock()
response_mock.status_code = 200
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_azure_connection_get_test_user_ids(get_mock):
@patch('utils.azure_users.AzureConnection.get_test_user_ids')
@patch('utils.azure_users.AzureConnection.users')
def test_azure_connection_get_non_test_users(
users_mock, get_test_user_ids_mock
users_mock, get_test_user_ids_mock
):
test_user = AzureUser('ID1', None, None, [], [])
non_test_user = AzureUser('ID2', None, None, [], [])
Expand All @@ -81,7 +81,7 @@ def test_azure_connection_get_group_id_by_group_name(get_mock):
group_id = 'ID1'
azure_connection = AzureConnection()
assert (
azure_connection.get_group_id_by_group_name('group_name') == group_id
azure_connection.get_group_id_by_group_name('group_name') == group_id
)


Expand All @@ -91,7 +91,7 @@ def test_azure_connection_get_group_id_by_group_name(get_mock):
@patch('requests.post')
@mark.parametrize('expected_value', [True, False])
def test_is_user_in_group(
post_mock, get_group_id_by_group_name_mock, expected_value
post_mock, get_group_id_by_group_name_mock, expected_value
):
response_expected = {'value': expected_value}
response_mock = Mock()
Expand All @@ -104,8 +104,8 @@ def test_is_user_in_group(

azure_connection = AzureConnection()
assert (
azure_connection.is_user_in_group('user_id', payload_mock)
== response_expected
azure_connection.is_user_in_group('user_id', payload_mock)
== response_expected
)


Expand Down Expand Up @@ -164,7 +164,7 @@ def test_get_groups_and_users(get_mock):
],
)
def test_get_groups_by_user_id(
get_groups_and_users_mock, user_id, groups_expected_value
get_groups_and_users_mock, user_id, groups_expected_value
):
get_groups_and_users_mock.return_value = [
('test-group-1', ['user-id1', 'user-id2']),
Expand All @@ -180,7 +180,7 @@ def test_get_groups_by_user_id(
@patch('utils.azure_users.AzureConnection.get_token', Mock())
@patch('utils.azure_users.AzureConnection.get_groups_and_users')
def test_get_groups_and_users_called_once_by_instance(
get_groups_and_users_mock,
get_groups_and_users_mock,
):
get_groups_and_users_mock.return_value = []
user_id = 'user-id1'
Expand All @@ -198,7 +198,7 @@ def test_get_groups_and_users_called_once_by_instance(
@patch('utils.azure_users.AzureConnection.get_group_id_by_group_name')
@patch('requests.post')
def test_add_user_to_group(
post_mock, get_group_id_by_group_name_mock, get_user_mock
post_mock, get_group_id_by_group_name_mock, get_user_mock
):
get_group_id_by_group_name_mock.return_value = 'dummy_group'
test_user = AzureUser('ID1', None, None, [], [])
Expand All @@ -224,7 +224,7 @@ def test_add_user_to_group(
@patch('utils.azure_users.AzureConnection.get_group_id_by_group_name')
@patch('requests.delete')
def test_remove_user_from_group(
delete_mock, get_group_id_by_group_name_mock, get_user_mock
delete_mock, get_group_id_by_group_name_mock, get_user_mock
):
get_group_id_by_group_name_mock.return_value = 'dummy_group'
test_user = AzureUser('ID1', None, None, [], [])
Expand All @@ -247,7 +247,7 @@ def test_remove_user_from_group(
@patch('utils.azure_users.AzureConnection.get_groups_and_users')
@patch('requests.get')
def test_users_functions_should_returns_all_users(
get_mock, get_groups_and_users_mock
get_mock, get_groups_and_users_mock
):
first_response = Response()
first_response.status_code = 200
Expand Down
19 changes: 15 additions & 4 deletions utils/azure_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, config=MSConfig):
self.client = self.get_msal_client()
self.access_token = self.get_token()
self.groups_and_users = None

def get_blob_storage_connection_string(self) -> str:
return self.config.AZURE_STORAGE_CONNECTION_STRING

Expand Down Expand Up @@ -187,7 +187,15 @@ def add_user_to_group(self, user_id, group_name):
headers=HTTP_PATCH_HEADERS,
)
assert 204 == response.status_code

if self.groups_and_users is None:
self.groups_and_users = [(group_name, [user_id])]
elif group_name not in [gn for (gn, ul) in self.groups_and_users]:
self.groups_and_users.append((group_name, [user_id]))
else:
for (cache_group_name, user_ids) in self.groups_and_users:
if group_name == cache_group_name:
if user_id not in user_ids:
user_ids.append(user_id)
return self.get_user(user_id)

def remove_user_from_group(self, user_id, group_name):
Expand All @@ -201,7 +209,11 @@ def remove_user_from_group(self, user_id, group_name):
headers=HTTP_PATCH_HEADERS,
)
assert 204 == response.status_code

if self.groups_and_users is not None:
for (cache_group_name, user_ids) in self.groups_and_users:
if group_name == cache_group_name:
if user_id in user_ids:
user_ids.remove(user_id)
return self.get_user(user_id)

def get_non_test_users(self) -> List[AzureUser]:
Expand Down Expand Up @@ -271,7 +283,6 @@ def get_groups_and_users(self):
result = list(map(parse_item, response.json()['value']))
users_id = self.config.USERID.split(",")
result[0][1].extend(users_id)

return result

def is_user_in_group(self, user_id, data: dict):
Expand Down