diff --git a/commons/data_access_layer/cosmos_db.py b/commons/data_access_layer/cosmos_db.py index 40416cec..10bc685d 100644 --- a/commons/data_access_layer/cosmos_db.py +++ b/commons/data_access_layer/cosmos_db.py @@ -2,7 +2,7 @@ import logging import uuid from datetime import datetime -from typing import Callable, List, Tuple +from typing import Callable, List, Dict import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions @@ -326,7 +326,7 @@ def get_current_month() -> int: def get_date_range_of_month( year: int, month: int -) -> Tuple[datetime, datetime]: +) -> Dict[str, str]: first_day_of_month = 1 start_date = datetime(year=year, month=month, day=first_day_of_month) @@ -340,4 +340,8 @@ def get_date_range_of_month( second=59, microsecond=999999 ) - return start_date, end_date + + return { + 'start_date': datetime_str(start_date), + 'end_date': datetime_str(end_date) + } diff --git a/tests/time_tracker_api/time_entries/time_entries_namespace_test.py b/tests/time_tracker_api/time_entries/time_entries_namespace_test.py index 19a7561b..35140fee 100644 --- a/tests/time_tracker_api/time_entries/time_entries_namespace_test.py +++ b/tests/time_tracker_api/time_entries/time_entries_namespace_test.py @@ -10,6 +10,7 @@ from commons.data_access_layer.cosmos_db import current_datetime, \ current_datetime_str, get_date_range_of_month, get_current_month, \ get_current_year, datetime_str +from commons.data_access_layer.database import EventContext from time_tracker_api.time_entries.time_entries_model import TimeEntriesCosmosDBDao fake = Faker() @@ -449,7 +450,6 @@ def test_create_with_valid_uuid_format_should_return_created(client: FlaskClient def test_find_all_is_called_with_generated_dates(client: FlaskClient, mocker: MockFixture, valid_header: dict, - tenant_id: str, owner_id: str, url: str, month: int, @@ -459,23 +459,15 @@ def test_find_all_is_called_with_generated_dates(client: FlaskClient, 'find_all', return_value=fake_time_entry) - response = client.get(url, - headers=valid_header, - follow_redirects=True) - - start_date, end_date = get_date_range_of_month(year, month) - custom_args = { - 'start_date': datetime_str(start_date), - 'end_date': datetime_str(end_date) - } + response = client.get(url, headers=valid_header, follow_redirects=True) + date_range = get_date_range_of_month(year, month) conditions = { 'owner_id': owner_id } assert HTTPStatus.OK == response.status_code assert json.loads(response.data) is not None - repository_find_all_mock.assert_called_once_with(partition_key_value=tenant_id, - conditions=conditions, - custom_args=custom_args) - + repository_find_all_mock.assert_called_once_with(ANY, + conditions=conditions, + date_range=date_range) diff --git a/time_tracker_api/time_entries/time_entries_model.py b/time_tracker_api/time_entries/time_entries_model.py index 32d97c89..10cd7989 100644 --- a/time_tracker_api/time_entries/time_entries_model.py +++ b/time_tracker_api/time_entries/time_entries_model.py @@ -221,25 +221,7 @@ def get_all(self, conditions: dict = {}) -> list: event_ctx = self.create_event_context("read-many") conditions.update({"owner_id": event_ctx.user_id}) - if 'month' and 'year' in conditions: - month = int(conditions.get("month")) - year = int(conditions.get("year")) - conditions.pop('month') - conditions.pop('year') - elif 'month' in conditions: - month = int(conditions.get("month")) - year = get_current_year() - conditions.pop('month') - else: - month = get_current_month() - year = get_current_year() - - start_date, end_date = get_date_range_of_month(year, month) - - date_range = { - 'start_date': start_date.isoformat(), - 'end_date': end_date.isoformat(), - } + date_range = self.handle_date_filter_args(args=conditions) return self.repository.find_all(event_ctx, conditions=conditions, date_range=date_range) @@ -277,6 +259,22 @@ def find_running(self): event_ctx = self.create_event_context("find_running") return self.repository.find_running(event_ctx.tenant_id, event_ctx.user_id) + @staticmethod + def handle_date_filter_args(args: dict) -> dict: + if 'month' and 'year' in args: + month = int(args.get("month")) + year = int(args.get("year")) + args.pop('month') + args.pop('year') + elif 'month' in args: + month = int(args.get("month")) + year = get_current_year() + args.pop('month') + else: + month = get_current_month() + year = get_current_year() + return get_date_range_of_month(year, month) + def create_dao() -> TimeEntriesDao: repository = TimeEntryCosmosDBRepository()