diff --git a/commons/data_access_layer/cosmos_db.py b/commons/data_access_layer/cosmos_db.py
index 8ac412c6..b8af2cbd 100644
--- a/commons/data_access_layer/cosmos_db.py
+++ b/commons/data_access_layer/cosmos_db.py
@@ -210,20 +210,25 @@ def find(
         self,
         id: str,
         event_context: EventContext,
-        peeker: Callable = None,
         visible_only=True,
         mapper: Callable = None,
     ):
         partition_key_value = self.find_partition_key_value(event_context)
         found_item = self.container.read_item(id, partition_key_value)
-        if peeker:
-            peeker(found_item)
-
         function_mapper = self.get_mapper_or_dict(mapper)
         return function_mapper(self.check_visibility(found_item, visible_only))

-    def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sql_conditions: List[str] = [],
-                 custom_params: dict = {}, max_count=None, offset=0, visible_only=True, mapper: Callable = None):
+    def find_all(
+        self,
+        event_context: EventContext,
+        conditions: dict = {},
+        custom_sql_conditions: List[str] = [],
+        custom_params: dict = {},
+        max_count=None,
+        offset=0,
+        visible_only=True,
+        mapper: Callable = None,
+    ):
         partition_key_value = self.find_partition_key_value(event_context)
         max_count = self.get_page_size_or(max_count)
         params = [
@@ -242,16 +247,16 @@ def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sq
             {order_clause}
             OFFSET @offset LIMIT @max_count
             """.format(
-                partition_key_attribute=self.partition_key_attribute,
-                visibility_condition=self.create_sql_condition_for_visibility(
-                    visible_only
-                ),
-                conditions_clause=self.create_sql_where_conditions(conditions),
-                custom_sql_conditions_clause=self.create_custom_sql_conditions(
-                    custom_sql_conditions
-                ),
-                order_clause=self.create_sql_order_clause(),
-            )
+            partition_key_attribute=self.partition_key_attribute,
+            visibility_condition=self.create_sql_condition_for_visibility(
+                visible_only
+            ),
+            conditions_clause=self.create_sql_where_conditions(conditions),
+            custom_sql_conditions_clause=self.create_custom_sql_conditions(
+                custom_sql_conditions
+            ),
+            order_clause=self.create_sql_order_clause(),
+        )

         result = self.container.query_items(
             query=query_str,
@@ -268,16 +273,11 @@ def partial_update(
         id: str,
         changes: dict,
         event_context: EventContext,
-        peeker: Callable = None,
         visible_only=True,
         mapper: Callable = None,
     ):
         item_data = self.find(
-            id,
-            event_context,
-            peeker=peeker,
-            visible_only=visible_only,
-            mapper=dict,
+            id, event_context, visible_only=visible_only, mapper=dict,
         )
         item_data.update(changes)
         return self.update(id, item_data, event_context, mapper=mapper)
@@ -295,17 +295,12 @@ def update(
         return function_mapper(self.container.replace_item(id, body=item_data))

     def delete(
-        self,
-        id: str,
-        event_context: EventContext,
-        peeker: Callable = None,
-        mapper: Callable = None,
+        self, id: str, event_context: EventContext, mapper: Callable = None,
     ):
         return self.partial_update(
             id,
             {'deleted': generate_uuid4()},
             event_context,
-            peeker=peeker,
             visible_only=True,
             mapper=mapper,
         )
diff --git a/tests/commons/data_access_layer/cosmos_db_test.py b/tests/commons/data_access_layer/cosmos_db_test.py
index dd54038a..6cb6deb7 100644
--- a/tests/commons/data_access_layer/cosmos_db_test.py
+++ b/tests/commons/data_access_layer/cosmos_db_test.py
@@ -3,13 +3,21 @@
 from typing import Callable

 import pytest
-from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError
+from azure.cosmos.exceptions import (
+    CosmosResourceExistsError,
+    CosmosResourceNotFoundError,
+)
 from faker import Faker
 from
flask_restplus._http import HTTPStatus from pytest import fail -from commons.data_access_layer.cosmos_db import CosmosDBRepository, CosmosDBModel, CustomError, current_datetime, \ - datetime_str +from commons.data_access_layer.cosmos_db import ( + CosmosDBRepository, + CosmosDBModel, + CustomError, + current_datetime, + datetime_str, +) from commons.data_access_layer.database import EventContext fake = Faker() @@ -35,14 +43,18 @@ def test_repository_exists(cosmos_db_repository): assert cosmos_db_repository is not None -def test_create_should_succeed(cosmos_db_repository: CosmosDBRepository, - tenant_id: str, - event_context: EventContext): - sample_item = dict(id=fake.uuid4(), - name=fake.name(), - email=fake.safe_email(), - age=fake.pyint(min_value=10, max_value=80), - tenant_id=tenant_id) +def test_create_should_succeed( + cosmos_db_repository: CosmosDBRepository, + tenant_id: str, + event_context: EventContext, +): + sample_item = dict( + id=fake.uuid4(), + name=fake.name(), + email=fake.safe_email(), + age=fake.pyint(min_value=10, max_value=80), + tenant_id=tenant_id, + ) created_item = cosmos_db_repository.create(sample_item, event_context) @@ -50,9 +62,11 @@ def test_create_should_succeed(cosmos_db_repository: CosmosDBRepository, assert all(item in created_item.items() for item in sample_item.items()) -def test_create_should_fail_if_user_is_same(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): +def test_create_should_fail_if_user_is_same( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): try: cosmos_db_repository.create(sample_item, event_context) @@ -62,27 +76,30 @@ def test_create_should_fail_if_user_is_same(cosmos_db_repository: CosmosDBReposi assert e.status_code == 409 -def test_create_with_diff_unique_data_but_same_tenant_should_succeed(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): +def test_create_with_diff_unique_data_but_same_tenant_should_succeed( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): new_data = sample_item.copy() - new_data.update({ - 'id': fake.uuid4(), - 'email': fake.safe_email(), - }) + new_data.update( + {'id': fake.uuid4(), 'email': fake.safe_email(),} + ) result = cosmos_db_repository.create(new_data, event_context) assert result["id"] != sample_item["id"], 'It should be a new element' -def test_create_with_same_id_should_fail(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): +def test_create_with_same_id_should_fail( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): try: new_data = sample_item.copy() - new_data.update({ - 'email': fake.safe_email(), - }) + new_data.update( + {'email': fake.safe_email(),} + ) cosmos_db_repository.create(new_data, event_context) @@ -92,14 +109,14 @@ def test_create_with_same_id_should_fail(cosmos_db_repository: CosmosDBRepositor assert e.status_code == 409 -def test_create_with_diff_id_but_same_unique_field_should_fail(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): +def test_create_with_diff_id_but_same_unique_field_should_fail( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): try: new_data = sample_item.copy() - new_data.update({ - 'id': fake.uuid4() - }) + new_data.update({'id': fake.uuid4()}) cosmos_db_repository.create(new_data, 
event_context) @@ -109,47 +126,62 @@ def test_create_with_diff_id_but_same_unique_field_should_fail(cosmos_db_reposit assert e.status_code == 409 -def test_create_with_same_id_but_diff_partition_key_attrib_should_succeed(cosmos_db_repository: CosmosDBRepository, - another_event_context: EventContext, - sample_item: dict, - another_tenant_id: str): +def test_create_with_same_id_but_diff_partition_key_attrib_should_succeed( + cosmos_db_repository: CosmosDBRepository, + another_event_context: EventContext, + sample_item: dict, + another_tenant_id: str, +): new_data = sample_item.copy() - new_data.update({ - 'tenant_id': another_tenant_id, - }) + new_data.update( + {'tenant_id': another_tenant_id,} + ) result = cosmos_db_repository.create(new_data, another_event_context) assert result["id"] == sample_item["id"], "Should have allowed same id" -def test_create_with_mapper_should_provide_calculated_fields(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - tenant_id: str): - new_item = dict(id=fake.uuid4(), - name=fake.name(), - email=fake.safe_email(), - age=fake.pyint(min_value=10, max_value=80), - tenant_id=tenant_id) +def test_create_with_mapper_should_provide_calculated_fields( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + tenant_id: str, +): + new_item = dict( + id=fake.uuid4(), + name=fake.name(), + email=fake.safe_email(), + age=fake.pyint(min_value=10, max_value=80), + tenant_id=tenant_id, + ) - created_item: Person = cosmos_db_repository.create(new_item, event_context, mapper=Person) + created_item: Person = cosmos_db_repository.create( + new_item, event_context, mapper=Person + ) assert created_item is not None - assert all(item in created_item.__dict__.items() for item in new_item.items()) - assert type(created_item) is Person, "The result should be wrapped with a class" + assert all( + item in created_item.__dict__.items() for item in new_item.items() + ) + assert ( + type(created_item) is Person + ), "The result should be wrapped with a class" assert created_item.is_adult() is (new_item["age"] >= 18) -def test_find_by_valid_id_should_succeed(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): +def test_find_by_valid_id_should_succeed( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): found_item = cosmos_db_repository.find(sample_item["id"], event_context) assert all(item in found_item.items() for item in sample_item.items()) -def test_find_by_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_find_by_invalid_id_should_fail( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext +): try: cosmos_db_repository.find(fake.uuid4(), event_context) @@ -159,8 +191,9 @@ def test_find_by_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository assert e.status_code == 404 -def test_find_by_invalid_partition_key_value_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_find_by_invalid_partition_key_value_should_fail( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext +): try: cosmos_db_repository.find(fake.uuid4(), event_context) @@ -170,51 +203,75 @@ def test_find_by_invalid_partition_key_value_should_fail(cosmos_db_repository: C assert e.status_code == 404 -def test_find_by_valid_id_and_mapper_should_succeed(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): - 
found_item: Person = cosmos_db_repository.find(sample_item["id"], - event_context, - mapper=Person) +def test_find_by_valid_id_and_mapper_should_succeed( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, +): + found_item: Person = cosmos_db_repository.find( + sample_item["id"], event_context, mapper=Person + ) found_item_dict = found_item.__dict__ - assert all(attrib in sample_item.items() for attrib in found_item_dict.items()) - assert type(found_item) is Person, "The result should be wrapped with a class" + assert all( + attrib in sample_item.items() for attrib in found_item_dict.items() + ) + assert ( + type(found_item) is Person + ), "The result should be wrapped with a class" assert found_item.is_adult() is (sample_item["age"] >= 18) @pytest.mark.parametrize( 'mapper,expected_type', [(None, dict), (dict, dict), (Person, Person)] ) -def test_find_all_with_mapper(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - mapper: Callable, - expected_type: Callable): +def test_find_all_with_mapper( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + mapper: Callable, + expected_type: Callable, +): result = cosmos_db_repository.find_all(event_context, mapper=mapper) assert result is not None assert len(result) > 0 - assert type(result[0]) is expected_type, "The result type is not the expected" + assert ( + type(result[0]) is expected_type + ), "The result type is not the expected" -def test_find_all_should_return_items_from_specified_partition_key_value(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - another_event_context: EventContext): +def test_find_all_should_return_items_from_specified_partition_key_value( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + another_event_context: EventContext, +): result_tenant_id = cosmos_db_repository.find_all(event_context) assert len(result_tenant_id) > 1 - assert all((i["tenant_id"] == event_context.tenant_id for i in result_tenant_id)) + assert all( + (i["tenant_id"] == event_context.tenant_id for i in result_tenant_id) + ) - result_another_tenant_id = cosmos_db_repository.find_all(another_event_context) + result_another_tenant_id = cosmos_db_repository.find_all( + another_event_context + ) assert len(result_another_tenant_id) > 0 - assert all((i["tenant_id"] == another_event_context.tenant_id for i in result_another_tenant_id)) + assert all( + ( + i["tenant_id"] == another_event_context.tenant_id + for i in result_another_tenant_id + ) + ) - assert not any(item in result_another_tenant_id for item in result_tenant_id), \ - "There should be no interceptions" + assert not any( + item in result_another_tenant_id for item in result_tenant_id + ), "There should be no interceptions" -def test_find_all_should_succeed_with_partition_key_value_with_no_items(cosmos_db_repository: CosmosDBRepository): +def test_find_all_should_succeed_with_partition_key_value_with_no_items( + cosmos_db_repository: CosmosDBRepository, +): invalid_event_context = EventContext("test", "any", tenant_id=fake.uuid4()) no_items = cosmos_db_repository.find_all(invalid_event_context) @@ -223,8 +280,9 @@ def test_find_all_should_succeed_with_partition_key_value_with_no_items(cosmos_d assert len(no_items) == 0, "No items are expected" -def test_find_all_with_max_count(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_find_all_with_max_count( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext 
+): all_items = cosmos_db_repository.find_all(event_context) assert len(all_items) > 2 @@ -233,17 +291,22 @@ def test_find_all_with_max_count(cosmos_db_repository: CosmosDBRepository, assert len(first_two_items) == 2, "The result should be limited to 2" -def test_find_all_with_offset(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_find_all_with_offset( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext +): result_all_items = cosmos_db_repository.find_all(event_context) assert len(result_all_items) >= 3 - result_after_the_first_item = cosmos_db_repository.find_all(event_context, offset=1) + result_after_the_first_item = cosmos_db_repository.find_all( + event_context, offset=1 + ) assert result_after_the_first_item == result_all_items[1:] - result_after_the_second_item = cosmos_db_repository.find_all(event_context, offset=2) + result_after_the_second_item = cosmos_db_repository.find_all( + event_context, offset=2 + ) assert result_after_the_second_item == result_all_items[2:] @@ -251,49 +314,58 @@ def test_find_all_with_offset(cosmos_db_repository: CosmosDBRepository, @pytest.mark.parametrize( 'mapper,expected_type', [(None, dict), (dict, dict), (Person, Person)] ) -def test_partial_update_with_mapper(cosmos_db_repository: CosmosDBRepository, - mapper: Callable, - sample_item: dict, - event_context: EventContext, - expected_type: Callable): +def test_partial_update_with_mapper( + cosmos_db_repository: CosmosDBRepository, + mapper: Callable, + sample_item: dict, + event_context: EventContext, + expected_type: Callable, +): changes = { 'name': fake.name(), 'email': fake.safe_email(), } - updated_item = cosmos_db_repository.partial_update(sample_item['id'], changes, - event_context, mapper=mapper) + updated_item = cosmos_db_repository.partial_update( + sample_item['id'], changes, event_context, mapper=mapper + ) assert updated_item is not None assert type(updated_item) is expected_type def test_partial_update_with_new_partition_key_value_should_fail( - cosmos_db_repository: CosmosDBRepository, - another_event_context: EventContext, - sample_item: dict): + cosmos_db_repository: CosmosDBRepository, + another_event_context: EventContext, + sample_item: dict, +): changes = { 'name': fake.name(), 'email': fake.safe_email(), } try: - cosmos_db_repository.partial_update(sample_item['id'], changes, another_event_context) + cosmos_db_repository.partial_update( + sample_item['id'], changes, another_event_context + ) fail('It should have failed') except Exception as e: assert type(e) is CosmosResourceNotFoundError assert e.status_code == 404 -def test_partial_update_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_partial_update_with_invalid_id_should_fail( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext +): changes = { 'name': fake.name(), 'email': fake.safe_email(), } try: - cosmos_db_repository.partial_update(fake.uuid4(), changes, event_context) + cosmos_db_repository.partial_update( + fake.uuid4(), changes, event_context + ) fail('It should have failed') except Exception as e: assert type(e) is CosmosResourceNotFoundError @@ -301,17 +373,18 @@ def test_partial_update_with_invalid_id_should_fail(cosmos_db_repository: Cosmos def test_partial_update_should_only_update_fields_in_changes( - cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext): + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: 
EventContext, +): changes = { 'name': fake.name(), 'email': fake.safe_email(), } - updated_item = cosmos_db_repository.partial_update(sample_item['id'], - changes, - event_context) + updated_item = cosmos_db_repository.partial_update( + sample_item['id'], changes, event_context + ) assert updated_item is not None assert updated_item['name'] == changes["name"] != sample_item["name"] @@ -324,28 +397,29 @@ def test_partial_update_should_only_update_fields_in_changes( @pytest.mark.parametrize( 'mapper,expected_type', [(None, dict), (dict, dict), (Person, Person)] ) -def test_update_with_mapper(cosmos_db_repository: CosmosDBRepository, - mapper: Callable, - sample_item: dict, - event_context: EventContext, - expected_type: Callable): +def test_update_with_mapper( + cosmos_db_repository: CosmosDBRepository, + mapper: Callable, + sample_item: dict, + event_context: EventContext, + expected_type: Callable, +): changed_item = sample_item.copy() - changed_item.update({ - 'name': fake.name(), - 'email': fake.safe_email(), - }) + changed_item.update( + {'name': fake.name(), 'email': fake.safe_email(),} + ) - updated_item = cosmos_db_repository.update(sample_item['id'], - changed_item, - event_context, - mapper=mapper) + updated_item = cosmos_db_repository.update( + sample_item['id'], changed_item, event_context, mapper=mapper + ) assert updated_item is not None assert type(updated_item) is expected_type -def test_update_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext): +def test_update_with_invalid_id_should_fail( + cosmos_db_repository: CosmosDBRepository, event_context: EventContext +): changes = { 'name': fake.name(), 'email': fake.safe_email(), @@ -359,9 +433,11 @@ def test_update_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBReposi assert e.status_code == 404 -def test_update_with_partial_changes_without_required_fields_it_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): +def test_update_with_partial_changes_without_required_fields_it_should_fail( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): changes = { 'id': sample_item['id'], 'email': fake.safe_email(), @@ -377,16 +453,19 @@ def test_update_with_partial_changes_without_required_fields_it_should_fail(cosm def test_update_with_partial_changes_with_required_fields_should_delete_the_missing_ones( - cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): changes = { 'id': fake.uuid4(), 'email': fake.safe_email(), 'tenant_id': event_context.tenant_id, } - updated_item = cosmos_db_repository.update(sample_item['id'], changes, event_context) + updated_item = cosmos_db_repository.update( + sample_item['id'], changes, event_context + ) assert updated_item is not None assert updated_item['id'] == changes["id"] != sample_item["id"] @@ -403,9 +482,11 @@ def test_update_with_partial_changes_with_required_fields_should_delete_the_miss assert e.status_code == 404 -def test_delete_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - tenant_id: str): +def test_delete_with_invalid_id_should_fail( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + tenant_id: str, +): try: cosmos_db_repository.delete(fake.uuid4(), event_context) except Exception as e: @@ -416,28 +497,38 @@ def 
test_delete_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBReposi @pytest.mark.parametrize( 'mapper,expected_type', [(None, dict), (dict, dict), (Person, Person)] ) -def test_delete_with_mapper(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext, - mapper: Callable, - expected_type: Callable): - deleted_item = cosmos_db_repository.delete(sample_item['id'], event_context, mapper=mapper) +def test_delete_with_mapper( + cosmos_db_repository: CosmosDBRepository, + sample_item: dict, + event_context: EventContext, + mapper: Callable, + expected_type: Callable, +): + deleted_item = cosmos_db_repository.delete( + sample_item['id'], event_context, mapper=mapper + ) assert deleted_item is not None assert type(deleted_item) is expected_type try: - cosmos_db_repository.find(sample_item['id'], event_context, mapper=mapper) + cosmos_db_repository.find( + sample_item['id'], event_context, mapper=mapper + ) fail('It should have not found the deleted item') except Exception as e: assert type(e) is CosmosResourceNotFoundError assert e.status_code == 404 -def test_find_can_find_deleted_item_only_if_visibile_only_is_true(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): - deleted_item = cosmos_db_repository.delete(sample_item['id'], event_context) +def test_find_can_find_deleted_item_only_if_visibile_only_is_true( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): + deleted_item = cosmos_db_repository.delete( + sample_item['id'], event_context + ) assert deleted_item is not None assert deleted_item['deleted'] is not None @@ -448,34 +539,48 @@ def test_find_can_find_deleted_item_only_if_visibile_only_is_true(cosmos_db_repo assert type(e) is CosmosResourceNotFoundError assert e.status_code == 404 - found_deleted_item = cosmos_db_repository.find(sample_item['id'], event_context, visible_only=False) + found_deleted_item = cosmos_db_repository.find( + sample_item['id'], event_context, visible_only=False + ) assert found_deleted_item is not None -def test_find_all_can_find_deleted_items_only_if_visibile_only_is_true(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): - deleted_item = cosmos_db_repository.delete(sample_item['id'], event_context) +def test_find_all_can_find_deleted_items_only_if_visibile_only_is_true( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): + deleted_item = cosmos_db_repository.delete( + sample_item['id'], event_context + ) assert deleted_item is not None assert deleted_item['deleted'] is not None visible_items = cosmos_db_repository.find_all(event_context) assert visible_items is not None - assert any(item['id'] == sample_item['id'] for item in visible_items) == False, \ - 'The deleted item should not be visible' + assert ( + any(item['id'] == sample_item['id'] for item in visible_items) == False + ), 'The deleted item should not be visible' - all_items = cosmos_db_repository.find_all(event_context, visible_only=False) + all_items = cosmos_db_repository.find_all( + event_context, visible_only=False + ) assert all_items is not None - assert any(item['id'] == sample_item['id'] for item in all_items), \ - 'Deleted item should be visible' + assert any( + item['id'] == sample_item['id'] for item in all_items + ), 'Deleted item should be visible' -def test_delete_should_not_find_element_that_is_already_deleted(cosmos_db_repository: CosmosDBRepository, - 
event_context: EventContext, - sample_item: dict): - deleted_item = cosmos_db_repository.delete(sample_item['id'], event_context) +def test_delete_should_not_find_element_that_is_already_deleted( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): + deleted_item = cosmos_db_repository.delete( + sample_item['id'], event_context + ) assert deleted_item is not None @@ -487,10 +592,14 @@ def test_delete_should_not_find_element_that_is_already_deleted(cosmos_db_reposi assert e.status_code == 404 -def test_partial_update_should_not_find_element_that_is_already_deleted(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): - deleted_item = cosmos_db_repository.delete(sample_item['id'], event_context) +def test_partial_update_should_not_find_element_that_is_already_deleted( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): + deleted_item = cosmos_db_repository.delete( + sample_item['id'], event_context + ) assert deleted_item is not None @@ -499,9 +608,9 @@ def test_partial_update_should_not_find_element_that_is_already_deleted(cosmos_d 'name': fake.name(), 'email': fake.safe_email(), } - cosmos_db_repository.partial_update(deleted_item['id'], - changes, - event_context) + cosmos_db_repository.partial_update( + deleted_item['id'], changes, event_context + ) fail('It should have not found the deleted item') except Exception as e: @@ -509,9 +618,11 @@ def test_partial_update_should_not_find_element_that_is_already_deleted(cosmos_d assert e.status_code == 404 -def test_delete_permanently_with_invalid_id_should_fail(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - tenant_id: str): +def test_delete_permanently_with_invalid_id_should_fail( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + tenant_id: str, +): try: cosmos_db_repository.delete_permanently(fake.uuid4(), event_context) fail('It should have not found the deleted item') @@ -520,9 +631,11 @@ def test_delete_permanently_with_invalid_id_should_fail(cosmos_db_repository: Co assert e.status_code == 404 -def test_delete_permanently_with_valid_id_should_succeed(cosmos_db_repository: CosmosDBRepository, - event_context: EventContext, - sample_item: dict): +def test_delete_permanently_with_valid_id_should_succeed( + cosmos_db_repository: CosmosDBRepository, + event_context: EventContext, + sample_item: dict, +): found_item = cosmos_db_repository.find(sample_item['id'], event_context) assert found_item is not None @@ -538,52 +651,40 @@ def test_delete_permanently_with_valid_id_should_succeed(cosmos_db_repository: C assert e.status_code == 404 -def test_repository_create_sql_where_conditions_with_multiple_values(cosmos_db_repository: CosmosDBRepository): - result = cosmos_db_repository.create_sql_where_conditions({ - 'owner_id': 'mark', - 'customer_id': 'me' - }, "c") +def test_repository_create_sql_where_conditions_with_multiple_values( + cosmos_db_repository: CosmosDBRepository, +): + result = cosmos_db_repository.create_sql_where_conditions( + {'owner_id': 'mark', 'customer_id': 'me'}, "c" + ) assert result is not None - assert result == "AND c.owner_id = @owner_id AND c.customer_id = @customer_id" + assert ( + result == "AND c.owner_id = @owner_id AND c.customer_id = @customer_id" + ) -def test_repository_create_sql_where_conditions_with_no_values(cosmos_db_repository: CosmosDBRepository): +def test_repository_create_sql_where_conditions_with_no_values( + 
cosmos_db_repository: CosmosDBRepository, +): result = cosmos_db_repository.create_sql_where_conditions({}, "c") assert result is not None assert result == "" -def test_repository_append_conditions_values(cosmos_db_repository: CosmosDBRepository): - result = cosmos_db_repository.generate_params({'owner_id': 'mark', 'customer_id': 'ioet'}) +def test_repository_append_conditions_values( + cosmos_db_repository: CosmosDBRepository, +): + result = cosmos_db_repository.generate_params( + {'owner_id': 'mark', 'customer_id': 'ioet'} + ) assert result is not None - assert result == [{'name': '@owner_id', 'value': 'mark'}, - {'name': '@customer_id', 'value': 'ioet'}] - - -def test_find_should_call_picker_if_it_was_specified(cosmos_db_repository: CosmosDBRepository, - sample_item: dict, - event_context: EventContext, - another_item: dict): - def raise_bad_request_if_name_diff_the_one_from_sample_item(data: dict): - if sample_item['name'] != data['name']: - raise CustomError(HTTPStatus.BAD_REQUEST, "Anything") - - found_item = cosmos_db_repository.find(sample_item['id'], event_context) - - assert found_item is not None - assert found_item['id'] == sample_item['id'] - - try: - cosmos_db_repository.find(another_item['id'], event_context, - peeker=raise_bad_request_if_name_diff_the_one_from_sample_item) - - fail('It should have not found any item because of condition') - except Exception as e: - assert e.code == HTTPStatus.BAD_REQUEST - assert e.description == "Anything" + assert result == [ + {'name': '@owner_id', 'value': 'mark'}, + {'name': '@customer_id', 'value': 'ioet'}, + ] def test_datetime_str_comparison(): @@ -601,15 +702,17 @@ def test_datetime_str_comparison(): def test_replace_empty_value_per_none(tenant_id: str): - initial_value = dict(id=fake.uuid4(), - name=fake.name(), - empty_str_attrib="", - array_attrib=[1, 2, 3], - empty_array_attrib=[], - description=" ", - age=fake.pyint(min_value=10, max_value=80), - size=0, - tenant_id=tenant_id) + initial_value = dict( + id=fake.uuid4(), + name=fake.name(), + empty_str_attrib="", + array_attrib=[1, 2, 3], + empty_array_attrib=[], + description=" ", + age=fake.pyint(min_value=10, max_value=80), + size=0, + tenant_id=tenant_id, + ) input = initial_value.copy() @@ -625,9 +728,9 @@ def test_replace_empty_value_per_none(tenant_id: str): assert input["tenant_id"] == initial_value["tenant_id"] -def test_attach_context_should_create_last_event_context_attrib(owner_id: str, - tenant_id: str, - event_context: EventContext): +def test_attach_context_should_create_last_event_context_attrib( + owner_id: str, tenant_id: str, event_context: EventContext +): data = dict() CosmosDBRepository.real_attach_context(data, event_context) diff --git a/tests/conftest.py b/tests/conftest.py index 362df371..1f8afe71 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,9 @@ from time_tracker_api import create_app from time_tracker_api.database import init_sql from time_tracker_api.security import get_or_generate_dev_secret_key -from time_tracker_api.time_entries.time_entries_model import TimeEntryCosmosDBRepository +from time_tracker_api.time_entries.time_entries_model import ( + TimeEntryCosmosDBRepository, +) fake = Faker() Faker.seed() @@ -37,6 +39,7 @@ def sql_model_class(app: Flask): init_sql(app) from commons.data_access_layer.sql import db + class PersonSQLModel(db.Model): __tablename__ = 'test' id = db.Column(db.Integer, primary_key=True) @@ -59,27 +62,29 @@ def sql_repository(app: Flask, sql_model_class): init_app(app) from 
commons.data_access_layer.sql import db - db.metadata.create_all(bind=db.engine, tables=[sql_model_class.__table__]) + db.metadata.create_all( + bind=db.engine, tables=[sql_model_class.__table__] + ) app.logger.info("SQl test models created!") from commons.data_access_layer.sql import SQLRepository + yield SQLRepository(sql_model_class) - db.metadata.drop_all(bind=db.engine, tables=[sql_model_class.__table__]) + db.metadata.drop_all( + bind=db.engine, tables=[sql_model_class.__table__] + ) app.logger.info("SQL test models removed!") @pytest.fixture(scope="module") def cosmos_db_model(): from azure.cosmos import PartitionKey + return { 'id': 'test', 'partition_key': PartitionKey(path='/tenant_id'), - 'unique_key_policy': { - 'uniqueKeys': [ - {'paths': ['/email']}, - ] - } + 'unique_key_policy': {'uniqueKeys': [{'paths': ['/email']},]}, } @@ -104,11 +109,18 @@ def cosmos_db_repository(app: Flask, cosmos_db_model) -> CosmosDBRepository: @pytest.fixture(scope="module") -def cosmos_db_dao(app: Flask, cosmos_db_repository: CosmosDBRepository) -> CosmosDBDao: +def cosmos_db_dao( + app: Flask, cosmos_db_repository: CosmosDBRepository +) -> CosmosDBDao: with app.app_context(): return CosmosDBDao(cosmos_db_repository) +@pytest.fixture +def valid_id() -> str: + return fake.uuid4() + + @pytest.fixture(scope="session") def tenant_id() -> str: return fake.uuid4() @@ -125,27 +137,35 @@ def owner_id() -> str: @pytest.fixture(scope="function") -def sample_item(cosmos_db_repository: CosmosDBRepository, - tenant_id: str, - event_context: EventContext) -> dict: - sample_item_data = dict(id=fake.uuid4(), - name=fake.name(), - email=fake.safe_email(), - age=fake.pyint(min_value=10, max_value=80), - tenant_id=tenant_id) +def sample_item( + cosmos_db_repository: CosmosDBRepository, + tenant_id: str, + event_context: EventContext, +) -> dict: + sample_item_data = dict( + id=fake.uuid4(), + name=fake.name(), + email=fake.safe_email(), + age=fake.pyint(min_value=10, max_value=80), + tenant_id=tenant_id, + ) return cosmos_db_repository.create(sample_item_data, event_context) @pytest.fixture(scope="function") -def another_item(cosmos_db_repository: CosmosDBRepository, - tenant_id: str, - event_context: EventContext) -> dict: - sample_item_data = dict(id=fake.uuid4(), - name=fake.name(), - email=fake.safe_email(), - age=fake.pyint(min_value=10, max_value=80), - tenant_id=tenant_id) +def another_item( + cosmos_db_repository: CosmosDBRepository, + tenant_id: str, + event_context: EventContext, +) -> dict: + sample_item_data = dict( + id=fake.uuid4(), + name=fake.name(), + email=fake.safe_email(), + age=fake.pyint(min_value=10, max_value=80), + tenant_id=tenant_id, + ) return cosmos_db_repository.create(sample_item_data, event_context) @@ -162,31 +182,40 @@ def time_entry_repository(app: Flask) -> TimeEntryCosmosDBRepository: @pytest.yield_fixture(scope="module") -def running_time_entry(time_entry_repository: TimeEntryCosmosDBRepository, - owner_id: str, - tenant_id: str, - event_context: EventContext): - created_time_entry = time_entry_repository.create({ - "project_id": fake.uuid4(), - "owner_id": owner_id, - "tenant_id": tenant_id - }, event_context) +def running_time_entry( + time_entry_repository: TimeEntryCosmosDBRepository, + owner_id: str, + tenant_id: str, + event_context: EventContext, +): + created_time_entry = time_entry_repository.create( + { + "project_id": fake.uuid4(), + "owner_id": owner_id, + "tenant_id": tenant_id, + }, + event_context, + ) yield created_time_entry - 
time_entry_repository.delete_permanently(id=created_time_entry.id, - event_context=event_context) + time_entry_repository.delete_permanently( + id=created_time_entry.id, event_context=event_context + ) @pytest.fixture(scope="session") def valid_jwt(app: Flask, tenant_id: str, owner_id: str) -> str: with app.app_context(): expiration_time = datetime.utcnow() + timedelta(seconds=3600) - return jwt.encode({ - "iss": "https://ioetec.b2clogin.com/%s/v2.0/" % tenant_id, - "oid": owner_id, - 'exp': expiration_time - }, key=get_or_generate_dev_secret_key()).decode("UTF-8") + return jwt.encode( + { + "iss": "https://ioetec.b2clogin.com/%s/v2.0/" % tenant_id, + "oid": owner_id, + 'exp': expiration_time, + }, + key=get_or_generate_dev_secret_key(), + ).decode("UTF-8") @pytest.fixture(scope="session") @@ -196,13 +225,11 @@ def valid_header(valid_jwt: str) -> dict: @pytest.fixture(scope="session") def event_context(owner_id: str, tenant_id: str) -> EventContext: - return EventContext("test", "any", - user_id=owner_id, - tenant_id=tenant_id) + return EventContext("test", "any", user_id=owner_id, tenant_id=tenant_id) @pytest.fixture(scope="session") def another_event_context(another_tenant_id: str) -> EventContext: - return EventContext("test", "any", - user_id=fake.uuid4(), - tenant_id=another_tenant_id) + return EventContext( + "test", "any", user_id=fake.uuid4(), tenant_id=another_tenant_id + ) diff --git a/tests/time_tracker_api/time_entries/time_entries_namespace_test.py b/tests/time_tracker_api/time_entries/time_entries_namespace_test.py index a29dbc98..13fd9268 100644 --- a/tests/time_tracker_api/time_entries/time_entries_namespace_test.py +++ b/tests/time_tracker_api/time_entries/time_entries_namespace_test.py @@ -10,15 +10,13 @@ from commons.data_access_layer.cosmos_db import ( current_datetime, current_datetime_str, - get_date_range_of_month, get_current_month, get_current_year, ) -from time_tracker_api.time_entries.custom_modules import worked_time -from time_tracker_api.time_entries.time_entries_model import ( - TimeEntriesCosmosDBDao, -) +from utils import worked_time + +from werkzeug.exceptions import NotFound, UnprocessableEntity, HTTPException fake = Faker() @@ -39,7 +37,7 @@ fake_time_entry.update(valid_time_entry_input) -def test_create_time_entry_with_invalid_date_range_should_raise_bad_request_error( +def test_create_time_entry_with_invalid_date_range_should_raise_bad_request( client: FlaskClient, mocker: MockFixture, valid_header: dict ): from time_tracker_api.time_entries.time_entries_namespace import ( @@ -65,7 +63,7 @@ def test_create_time_entry_with_invalid_date_range_should_raise_bad_request_erro repository_container_create_item_mock.assert_not_called() -def test_create_time_entry_with_end_date_in_future_should_raise_bad_request_error( +def test_create_time_entry_with_end_date_in_future_should_raise_bad_request( client: FlaskClient, mocker: MockFixture, valid_header: dict ): from time_tracker_api.time_entries.time_entries_namespace import ( @@ -182,59 +180,64 @@ def test_get_time_entry_should_succeed_with_valid_id( dao_get_mock.assert_called_once_with(str(valid_id)) -def test_get_time_entry_should_response_with_unprocessable_entity_for_invalid_id_format( - client: FlaskClient, mocker: MockFixture, valid_header: dict +@pytest.mark.parametrize( + 'http_exception,http_status', + [ + (NotFound, HTTPStatus.NOT_FOUND), + (UnprocessableEntity, HTTPStatus.UNPROCESSABLE_ENTITY), + ], +) +def test_get_time_entry_raise_http_exception( + client: FlaskClient, + mocker: MockFixture, + 
valid_header: dict, + valid_id: str, + http_exception: HTTPException, + http_status: tuple, ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - from werkzeug.exceptions import UnprocessableEntity - invalid_id = fake.word() - - repository_find_mock = mocker.patch.object( - time_entries_dao.repository, 'find', side_effect=UnprocessableEntity - ) + time_entries_dao.repository.find = Mock(side_effect=http_exception) response = client.get( - "/time-entries/%s" % invalid_id, + f"/time-entries/{valid_id}", headers=valid_header, follow_redirects=True, ) - assert HTTPStatus.UNPROCESSABLE_ENTITY == response.status_code - repository_find_mock.assert_called_once_with( - str(invalid_id), ANY, peeker=ANY - ) + assert http_status == response.status_code + time_entries_dao.repository.find.assert_called_once_with(valid_id, ANY) -def test_update_time_entry_should_succeed_with_valid_data( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_update_time_entry_calls_partial_update_with_incoming_payload( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, - 'partial_update', - return_value=fake_time_entry, - ) + time_entries_dao.repository.partial_update = Mock(return_value={}) + + time_entries_dao.repository.find = Mock(return_value={}) + time_entries_dao.check_whether_current_user_owns_item = Mock() - valid_id = fake.random_int(1, 9999) response = client.put( - "/time-entries/%s" % valid_id, + f'/time-entries/{valid_id}', headers=valid_header, json=valid_time_entry_input, follow_redirects=True, ) assert HTTPStatus.OK == response.status_code - fake_time_entry == json.loads(response.data) - repository_update_mock.assert_called_once_with( - str(valid_id), valid_time_entry_input, ANY, peeker=ANY + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, valid_time_entry_input, ANY ) + time_entries_dao.repository.find.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() + def test_update_time_entry_should_reject_bad_request( client: FlaskClient, mocker: MockFixture, valid_header: dict @@ -245,7 +248,7 @@ def test_update_time_entry_should_reject_bad_request( invalid_time_entry_data = valid_time_entry_input.copy() invalid_time_entry_data.update( - {"project_id": fake.pyint(min_value=1, max_value=100),} + {"project_id": fake.pyint(min_value=1, max_value=100)} ) repository_update_mock = mocker.patch.object( time_entries_dao.repository, 'update', return_value=fake_time_entry @@ -263,218 +266,204 @@ def test_update_time_entry_should_reject_bad_request( repository_update_mock.assert_not_called() -def test_update_time_entry_should_return_not_found_with_invalid_id( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_update_time_entry_raise_not_found( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) from werkzeug.exceptions import NotFound - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, 'partial_update', side_effect=NotFound - ) - invalid_id = fake.random_int(1, 9999) + time_entries_dao.repository.partial_update = Mock(side_effect=NotFound) + + time_entries_dao.repository.find = Mock(return_value={}) + 
time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.put( - "/time-entries/%s" % invalid_id, + f'/time-entries/{valid_id}', headers=valid_header, json=valid_time_entry_input, follow_redirects=True, ) assert HTTPStatus.NOT_FOUND == response.status_code - repository_update_mock.assert_called_once_with( - str(invalid_id), valid_time_entry_input, ANY, peeker=ANY + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, valid_time_entry_input, ANY ) + time_entries_dao.repository.find.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() -def test_delete_time_entry_should_succeed_with_valid_id( - client: FlaskClient, mocker: MockFixture, valid_header: dict + +def test_delete_time_entry_calls_delete( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - repository_remove_mock = mocker.patch.object( - time_entries_dao.repository, 'delete', return_value=None - ) - valid_id = fake.random_int(1, 9999) - + time_entries_dao.repository.delete = Mock(return_value=None) + time_entries_dao.repository.find = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.delete( - "/time-entries/%s" % valid_id, + f'/time-entries/{valid_id}', headers=valid_header, follow_redirects=True, ) assert HTTPStatus.NO_CONTENT == response.status_code assert b'' == response.data - repository_remove_mock.assert_called_once_with( - str(valid_id), ANY, peeker=ANY - ) + time_entries_dao.repository.delete.assert_called_once_with(valid_id, ANY) + time_entries_dao.repository.find.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() -def test_delete_time_entry_should_return_not_found_with_invalid_id( - client: FlaskClient, mocker: MockFixture, valid_header: dict +@pytest.mark.parametrize( + 'http_exception,http_status', + [ + (NotFound, HTTPStatus.NOT_FOUND), + (UnprocessableEntity, HTTPStatus.UNPROCESSABLE_ENTITY), + ], +) +def test_delete_time_entry_raise_http_exception( + client: FlaskClient, + mocker: MockFixture, + valid_header: dict, + valid_id: str, + http_exception: HTTPException, + http_status: tuple, ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - from werkzeug.exceptions import NotFound - repository_remove_mock = mocker.patch.object( - time_entries_dao.repository, 'delete', side_effect=NotFound - ) - invalid_id = fake.random_int(1, 9999) + time_entries_dao.repository.delete = Mock(side_effect=http_exception) + time_entries_dao.repository.find = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.delete( - "/time-entries/%s" % invalid_id, + f"/time-entries/{valid_id}", headers=valid_header, follow_redirects=True, ) - assert HTTPStatus.NOT_FOUND == response.status_code - repository_remove_mock.assert_called_once_with( - str(invalid_id), ANY, peeker=ANY - ) + assert http_status == response.status_code + time_entries_dao.repository.delete.assert_called_once_with(valid_id, ANY) + time_entries_dao.repository.find.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() -def test_delete_time_entry_should_return_unprocessable_entity_for_invalid_id_format( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_stop_time_entry_calls_partial_update( + client: FlaskClient, mocker: MockFixture, valid_header: 
dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - from werkzeug.exceptions import UnprocessableEntity - - repository_remove_mock = mocker.patch.object( - time_entries_dao.repository, 'delete', side_effect=UnprocessableEntity - ) - invalid_id = fake.word() - response = client.delete( - "/time-entries/%s" % invalid_id, - headers=valid_header, - follow_redirects=True, - ) - - assert HTTPStatus.UNPROCESSABLE_ENTITY == response.status_code - repository_remove_mock.assert_called_once_with( - str(invalid_id), ANY, peeker=ANY - ) - - -def test_stop_time_entry_with_valid_id( - client: FlaskClient, mocker: MockFixture, valid_header: dict -): - from time_tracker_api.time_entries.time_entries_namespace import ( - time_entries_dao, - ) + time_entries_dao.repository.partial_update = Mock(return_value={}) - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, - 'partial_update', - return_value=fake_time_entry, - ) - valid_id = fake.random_int(1, 9999) + time_entries_dao.repository.find = Mock(return_value={}) + time_entries_dao.check_time_entry_is_not_stopped = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.post( - "/time-entries/%s/stop" % valid_id, + f'/time-entries/{valid_id}/stop', headers=valid_header, follow_redirects=True, ) assert HTTPStatus.OK == response.status_code - repository_update_mock.assert_called_once_with( - str(valid_id), - {"end_date": mocker.ANY}, - ANY, - peeker=TimeEntriesCosmosDBDao.checks_owner_and_is_not_stopped, + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, {"end_date": ANY}, ANY ) + time_entries_dao.check_time_entry_is_not_stopped.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() -def test_stop_time_entry_with_id_with_invalid_format( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_stop_time_entry_raise_unprocessable_entity( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) from werkzeug.exceptions import UnprocessableEntity - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, - 'partial_update', - side_effect=UnprocessableEntity, + time_entries_dao.repository.partial_update = Mock( + side_effect=UnprocessableEntity ) - invalid_id = fake.word() + time_entries_dao.repository.find = Mock(return_value={}) + time_entries_dao.check_time_entry_is_not_stopped = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.post( - "/time-entries/%s/stop" % invalid_id, + f'/time-entries/{valid_id}/stop', headers=valid_header, follow_redirects=True, ) assert HTTPStatus.UNPROCESSABLE_ENTITY == response.status_code - repository_update_mock.assert_called_once_with( - invalid_id, - {"end_date": ANY}, - ANY, - peeker=TimeEntriesCosmosDBDao.checks_owner_and_is_not_stopped, + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, {"end_date": ANY}, ANY ) + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() + time_entries_dao.check_time_entry_is_not_stopped.assert_called_once() -def test_restart_time_entry_with_valid_id( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_restart_time_entry_calls_partial_update( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from 
time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, - 'partial_update', - return_value=fake_time_entry, - ) - valid_id = fake.random_int(1, 9999) + time_entries_dao.repository.partial_update = Mock(return_value={}) + + time_entries_dao.repository.find = Mock(return_value={}) + time_entries_dao.check_time_entry_is_not_started = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.post( - "/time-entries/%s/restart" % valid_id, + f'/time-entries/{valid_id}/restart', headers=valid_header, follow_redirects=True, ) assert HTTPStatus.OK == response.status_code - repository_update_mock.assert_called_once_with( - str(valid_id), {"end_date": None}, ANY, peeker=ANY + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, {"end_date": None}, ANY ) + time_entries_dao.check_time_entry_is_not_started.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() -def test_restart_time_entry_with_id_with_invalid_format( - client: FlaskClient, mocker: MockFixture, valid_header: dict +def test_restart_time_entry_raise_unprocessable_entity( + client: FlaskClient, mocker: MockFixture, valid_header: dict, valid_id: str ): from time_tracker_api.time_entries.time_entries_namespace import ( time_entries_dao, ) from werkzeug.exceptions import UnprocessableEntity - repository_update_mock = mocker.patch.object( - time_entries_dao.repository, - 'partial_update', - side_effect=UnprocessableEntity, - peeker=ANY, + time_entries_dao.repository.partial_update = Mock( + side_effect=UnprocessableEntity ) - invalid_id = fake.word() + + time_entries_dao.repository.find = Mock(return_value={}) + time_entries_dao.check_time_entry_is_not_started = Mock() + time_entries_dao.check_whether_current_user_owns_item = Mock() response = client.post( - "/time-entries/%s/restart" % invalid_id, + f'/time-entries/{valid_id}/restart', headers=valid_header, follow_redirects=True, ) assert HTTPStatus.UNPROCESSABLE_ENTITY == response.status_code - repository_update_mock.assert_called_once_with( - invalid_id, {"end_date": None}, ANY, peeker=ANY + time_entries_dao.repository.partial_update.assert_called_once_with( + valid_id, {"end_date": None}, ANY ) + time_entries_dao.check_time_entry_is_not_started.assert_called_once() + time_entries_dao.check_whether_current_user_owns_item.assert_called_once() def test_get_running_should_call_find_running( @@ -503,7 +492,7 @@ def test_get_running_should_call_find_running( repository_update_mock.assert_called_once_with(tenant_id, owner_id) -def test_get_running_should_return_not_found_if_find_running_throws_StopIteration( +def test_get_running_should_return_not_found_if_StopIteration( client: FlaskClient, mocker: MockFixture, valid_header: dict, diff --git a/time_tracker_api/projects/custom_modules/utils.py b/time_tracker_api/projects/custom_modules/utils.py deleted file mode 100644 index 1d4ba1bb..00000000 --- a/time_tracker_api/projects/custom_modules/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -# TODO this must be refactored to be used from the utils module ↓ -# Also check if we can change this using the overwritten __add__ method - - -def add_customer_name_to_projects(projects, customers): - for project in projects: - for customer in customers: - if project.customer_id == customer.id: - setattr(project, 'customer_name', customer.name) diff --git a/time_tracker_api/projects/projects_model.py 
b/time_tracker_api/projects/projects_model.py
index 54d7c0ba..62cba129 100644
--- a/time_tracker_api/projects/projects_model.py
+++ b/time_tracker_api/projects/projects_model.py
@@ -1,13 +1,18 @@
 from dataclasses import dataclass
 from azure.cosmos import PartitionKey
-from commons.data_access_layer.cosmos_db import CosmosDBModel, CosmosDBDao, CosmosDBRepository
+from commons.data_access_layer.cosmos_db import (
+    CosmosDBModel,
+    CosmosDBDao,
+    CosmosDBRepository,
+)
 from time_tracker_api.database import CRUDDao, APICosmosDBDao
-from time_tracker_api.customers.customers_model import create_dao as customers_create_dao
+from time_tracker_api.customers.customers_model import (
+    create_dao as customers_create_dao,
+)
 from time_tracker_api.customers.customers_model import CustomerCosmosDBModel
-from time_tracker_api.projects.custom_modules.utils import (
-    add_customer_name_to_projects
-)
+from utils.extend_model import add_customer_name_to_projects
+


 class ProjectDao(CRUDDao):
     pass
@@ -17,10 +22,8 @@ class ProjectDao(CRUDDao):
     'id': 'project',
     'partition_key': PartitionKey(path='/tenant_id'),
     'unique_key_policy': {
-        'uniqueKeys': [
-            {'paths': ['/name', '/customer_id', '/deleted']},
-        ]
-    }
+        'uniqueKeys': [{'paths': ['/name', '/customer_id', '/deleted']},]
+    },
 }


@@ -36,7 +39,7 @@ class ProjectCosmosDBModel(CosmosDBModel):
     technologies: list

     def __init__(self, data):
-        super(ProjectCosmosDBModel, self).__init__(data) # pragma: no cover
+        super(ProjectCosmosDBModel, self).__init__(data)  # pragma: no cover

     def __contains__(self, item):
         if type(item) is CustomerCosmosDBModel:
@@ -53,9 +56,12 @@ def __str___(self):

 class ProjectCosmosDBRepository(CosmosDBRepository):
     def __init__(self):
-        CosmosDBRepository.__init__(self, container_id=container_definition['id'],
-                                    partition_key_attribute='tenant_id',
-                                    mapper=ProjectCosmosDBModel)
+        CosmosDBRepository.__init__(
+            self,
+            container_id=container_definition['id'],
+            partition_key_attribute='tenant_id',
+            mapper=ProjectCosmosDBModel,
+        )


 class ProjectCosmosDBDao(APICosmosDBDao, ProjectDao):
@@ -75,12 +81,14 @@ def get_all(self, conditions: dict = None, **kwargs) -> list:
         customers_id = [customer.id for customer in customers]
         conditions = conditions if conditions else {}

-        custom_condition = "c.customer_id IN {}".format(str(tuple(customers_id)))
+        custom_condition = "c.customer_id IN {}".format(
+            str(tuple(customers_id))
+        )
         # TODO this must be refactored to be used from the utils module ↑
         if "custom_sql_conditions" in kwargs:
             kwargs["custom_sql_conditions"].append(custom_condition)
         else:
-            kwargs["custom_sql_conditions"] = [custom_condition]
+            kwargs["custom_sql_conditions"] = [custom_condition]

         projects = self.repository.find_all(event_ctx, conditions, **kwargs)
         add_customer_name_to_projects(projects, customers)
diff --git a/time_tracker_api/time_entries/custom_modules/utils.py b/time_tracker_api/time_entries/custom_modules/utils.py
deleted file mode 100644
index e2ed09d5..00000000
--- a/time_tracker_api/time_entries/custom_modules/utils.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# TODO this must be refactored to be used from the utils module ↓
-# Also check if we can improve this by using the overwritten __add__ method
-
-
-def add_project_name_to_time_entries(time_entries, projects):
-    for time_entry in time_entries:
-        for project in projects:
-            if time_entry.project_id == project.id:
-                setattr(time_entry, 'project_name', project.name)
diff --git a/time_tracker_api/time_entries/time_entries_model.py b/time_tracker_api/time_entries/time_entries_model.py
index e2bcdc19..644e3834 100644
--- a/time_tracker_api/time_entries/time_entries_model.py
+++ b/time_tracker_api/time_entries/time_entries_model.py
@@ -17,11 +17,10 @@
 )

 from commons.data_access_layer.database import EventContext
-from time_tracker_api.time_entries.custom_modules import worked_time
-from time_tracker_api.time_entries.custom_modules.utils import (
-    add_project_name_to_time_entries,
-)
-from time_tracker_api.projects.projects_model import ProjectCosmosDBModel, create_dao as project_create_dao
+from utils.extend_model import add_project_name_to_time_entries
+from utils import worked_time
+
+from time_tracker_api.projects.projects_model import ProjectCosmosDBModel
 from time_tracker_api.projects import projects_model
 from time_tracker_api.database import CRUDDao, APICosmosDBDao
 from time_tracker_api.security import current_user_id
@@ -142,11 +141,17 @@ def find_all(

         if time_entries:
             projects_id = [project.project_id for project in time_entries]
-            p_ids = str(tuple(projects_id)).replace(",", "") if len(projects_id) == 1 else str(tuple(projects_id))
+            p_ids = (
+                str(tuple(projects_id)).replace(",", "")
+                if len(projects_id) == 1
+                else str(tuple(projects_id))
+            )
             custom_conditions = "c.id IN {}".format(p_ids)
             # TODO this must be refactored to be used from the utils module ↑
             project_dao = projects_model.create_dao()
-            projects = project_dao.get_all(custom_sql_conditions=[custom_conditions])
+            projects = project_dao.get_all(
+                custom_sql_conditions=[custom_conditions]
+            )
             add_project_name_to_time_entries(time_entries, projects)

         return time_entries
@@ -270,32 +275,25 @@ class TimeEntriesCosmosDBDao(APICosmosDBDao, TimeEntriesDao):
     def __init__(self, repository):
         CosmosDBDao.__init__(self, repository)
-    @classmethod
-    def check_whether_current_user_owns_item(cls, data: dict):
+    def check_whether_current_user_owns_item(self, data):
         if (
-            data.get('owner_id') is not None
-            and data.get('owner_id') != cls.current_user_id()
+            data.owner_id is not None
+            and data.owner_id != self.current_user_id()
         ):
             raise CustomError(
                 HTTPStatus.FORBIDDEN,
                 "The current user is not the owner of this time entry",
             )

-    @classmethod
-    def checks_owner_and_is_not_stopped(cls, data: dict):
-        cls.check_whether_current_user_owns_item(data)
-
-        if data.get('end_date') is not None:
+    def check_time_entry_is_not_stopped(self, data):
+        if data.end_date is not None:
             raise CustomError(
                 HTTPStatus.UNPROCESSABLE_ENTITY,
                 "The specified time entry is already stopped",
             )

-    @classmethod
-    def checks_owner_and_is_not_started(cls, data: dict):
-        cls.check_whether_current_user_owns_item(data)
-
-        if data.get('end_date') is None:
+    def check_time_entry_is_not_started(self, data):
+        if data.end_date is None:
             raise CustomError(
                 HTTPStatus.UNPROCESSABLE_ENTITY,
                 "The specified time entry is already running",
             )
@@ -306,13 +304,15 @@ def get_all(self, conditions: dict = None, **kwargs) -> list:
         conditions.update({"owner_id": event_ctx.user_id})
         date_range = self.handle_date_filter_args(args=conditions)

-        return self.repository.find_all(event_ctx, conditions=conditions, date_range=date_range)
+        return self.repository.find_all(
+            event_ctx, conditions=conditions, date_range=date_range
+        )

     def get(self, id):
         event_ctx = self.create_event_context("read")
-        time_entry = self.repository.find(
-            id, event_ctx, peeker=self.check_whether_current_user_owns_item
-        )
+
+        time_entry = self.repository.find(id, event_ctx)
+        self.check_whether_current_user_owns_item(time_entry)

         project_dao = projects_model.create_dao()
         project = project_dao.get(time_entry.project_id)
@@ -326,35 +326,40 @@ def create(self, data: dict):

     def update(self, id, data: dict, description=None):
         event_ctx = self.create_event_context("update", description)
-        return self.repository.partial_update(
-            id,
-            data,
-            event_ctx,
-            peeker=self.check_whether_current_user_owns_item,
-        )
+
+        time_entry = self.repository.find(id, event_ctx)
+        self.check_whether_current_user_owns_item(time_entry)
+
+        return self.repository.partial_update(id, data, event_ctx,)

     def stop(self, id):
         event_ctx = self.create_event_context("update", "Stop time entry")
+
+        time_entry = self.repository.find(id, event_ctx)
+        self.check_whether_current_user_owns_item(time_entry)
+        self.check_time_entry_is_not_stopped(time_entry)
+
         return self.repository.partial_update(
-            id,
-            {'end_date': current_datetime_str()},
-            event_ctx,
-            peeker=self.checks_owner_and_is_not_stopped,
+            id, {'end_date': current_datetime_str()}, event_ctx,
         )

     def restart(self, id):
         event_ctx = self.create_event_context("update", "Restart time entry")
+
+        time_entry = self.repository.find(id, event_ctx)
+        self.check_whether_current_user_owns_item(time_entry)
+        self.check_time_entry_is_not_started(time_entry)
+
         return self.repository.partial_update(
-            id,
-            {'end_date': None},
-            event_ctx,
-            peeker=self.checks_owner_and_is_not_started,
+            id, {'end_date': None}, event_ctx,
         )

     def delete(self, id):
         event_ctx = self.create_event_context("delete")
+        time_entry = self.repository.find(id, event_ctx)
+        self.check_whether_current_user_owns_item(time_entry)
         self.repository.delete(
-            id, event_ctx, peeker=self.check_whether_current_user_owns_item
+            id, event_ctx,
         )

     def find_running(self):
@@ -397,7 +402,9 @@ def handle_date_filter_args(args: dict) -> dict:
     else:
         month = get_current_month()
         year = get_current_year()
-    return date_range if date_range else get_date_range_of_month(year, month)
+    return (
+        date_range if date_range else get_date_range_of_month(year, month)
+    )


 def create_dao() -> TimeEntriesDao:
diff --git a/utils/extend_model.py b/utils/extend_model.py
new file mode 100644
index 00000000..cd33cc5b
--- /dev/null
+++ b/utils/extend_model.py
@@ -0,0 +1,28 @@
+def add_customer_name_to_projects(projects, customers):
+    """
+    Add attribute customer_name in project model, based on customer_id of the
+    project
+    :param (list) projects: projects retrieved from project repository
+    :param (list) customers: customers retrieved from customer repository
+
+    TODO : check if we can improve this by using the overwritten __add__ method
+    """
+    for project in projects:
+        for customer in customers:
+            if project.customer_id == customer.id:
+                setattr(project, 'customer_name', customer.name)
+
+
+def add_project_name_to_time_entries(time_entries, projects):
+    """
+    Add attribute project_name in time-entry model, based on project_id of the
+    time_entry
+    :param (list) time_entries: time_entries retrieved from time-entry repository
+    :param (list) projects: projects retrieved from project repository
+
+    TODO : check if we can improve this by using the overwritten __add__ method
+    """
+    for time_entry in time_entries:
+        for project in projects:
+            if time_entry.project_id == project.id:
+                setattr(time_entry, 'project_name', project.name)
diff --git a/time_tracker_api/time_entries/custom_modules/worked_time.py b/utils/worked_time.py
similarity index 94%
rename from time_tracker_api/time_entries/custom_modules/worked_time.py
rename to utils/worked_time.py
index 29475b63..2b7e0862 100644
--- a/time_tracker_api/time_entries/custom_modules/worked_time.py
a/time_tracker_api/time_entries/custom_modules/worked_time.py +++ b/utils/worked_time.py @@ -4,12 +4,17 @@ current_datetime_str, datetime_str, get_current_month, - get_current_year + get_current_year, ) def start_datetime_of_current_month() -> datetime: - return datetime(year=get_current_year(), month=get_current_month(), day=1, tzinfo=timezone.utc) + return datetime( + year=get_current_year(), + month=get_current_month(), + day=1, + tzinfo=timezone.utc, + ) def start_datetime_of_current_week() -> datetime: @@ -33,7 +38,9 @@ def str_to_datetime( value: str, conversion_format: str = '%Y-%m-%dT%H:%M:%S.%fZ' ) -> datetime: if 'Z' in value: - return datetime.strptime(value, conversion_format).astimezone(timezone.utc) + return datetime.strptime(value, conversion_format).astimezone( + timezone.utc + ) else: return datetime.fromisoformat(value).astimezone(timezone.utc)
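
Editor's note, not part of the diff: the helpers moved into utils/extend_model.py rely only on the id, customer_id, project_id, and name attributes of the objects they receive, so they can be exercised without Cosmos DB. A minimal pytest-style sketch follows, assuming utils/ is on the import path and using SimpleNamespace stand-ins for the repository models; the test names are illustrative and not part of this change.

# Illustrative sketch only -- not included in the diff above.
from types import SimpleNamespace

from utils.extend_model import (
    add_customer_name_to_projects,
    add_project_name_to_time_entries,
)


def test_add_customer_name_to_projects_sets_matching_name():
    # Stand-ins for CustomerCosmosDBModel / ProjectCosmosDBModel instances.
    customers = [SimpleNamespace(id='c1', name='ACME')]
    projects = [SimpleNamespace(id='p1', name='Time Tracker', customer_id='c1')]

    add_customer_name_to_projects(projects, customers)

    assert projects[0].customer_name == 'ACME'


def test_add_project_name_to_time_entries_sets_matching_name():
    # Stand-ins for ProjectCosmosDBModel / TimeEntryCosmosDBModel instances.
    projects = [SimpleNamespace(id='p1', name='Time Tracker')]
    time_entries = [SimpleNamespace(project_id='p1')]

    add_project_name_to_time_entries(time_entries, projects)

    assert time_entries[0].project_name == 'Time Tracker'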