Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions commons/data_access_layer/cosmos_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from datetime import datetime
from typing import Callable
from typing import Callable, List, Dict

import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
Expand Down Expand Up @@ -116,7 +116,15 @@ def create_sql_where_conditions(conditions: dict, container_name='c') -> str:
return ""

@staticmethod
def generate_condition_values(conditions: dict) -> dict:
def create_custom_sql_conditions(custom_sql_conditions: List[str]) -> str:
if len(custom_sql_conditions) > 0:
return "AND {custom_sql_conditions_clause}".format(
custom_sql_conditions_clause=" AND ".join(custom_sql_conditions))
else:
return ''

@staticmethod
def generate_params(conditions: dict) -> dict:
result = []
for k, v in conditions.items():
result.append({
Expand Down Expand Up @@ -166,23 +174,25 @@ def find(self, id: str, event_context: EventContext, peeker: 'function' = None,
function_mapper = self.get_mapper_or_dict(mapper)
return function_mapper(self.check_visibility(found_item, visible_only))

def find_all(self, event_context: EventContext, conditions: dict = {}, max_count=None,
offset=0, visible_only=True, mapper: Callable = None):
def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sql_conditions: List[str] = [],
custom_params: dict = {}, max_count=None, offset=0, visible_only=True, mapper: Callable = None):
partition_key_value = self.find_partition_key_value(event_context)
max_count = self.get_page_size_or(max_count)
params = [
{"name": "@partition_key_value", "value": partition_key_value},
{"name": "@offset", "value": offset},
{"name": "@max_count", "value": max_count},
]
params.extend(self.generate_condition_values(conditions))
params.extend(self.generate_params(conditions))
params.extend(custom_params)
result = self.container.query_items(query="""
SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value
{conditions_clause} {visibility_condition} {order_clause}
{conditions_clause} {visibility_condition} {custom_sql_conditions_clause} {order_clause}
OFFSET @offset LIMIT @max_count
""".format(partition_key_attribute=self.partition_key_attribute,
visibility_condition=self.create_sql_condition_for_visibility(visible_only),
conditions_clause=self.create_sql_where_conditions(conditions),
custom_sql_conditions_clause=self.create_custom_sql_conditions(custom_sql_conditions),
order_clause=self.create_sql_order_clause()),
parameters=params,
partition_key=partition_key_value,
Expand Down Expand Up @@ -298,3 +308,40 @@ def generate_uuid4() -> str:
def init_app(app: Flask) -> None:
global cosmos_helper
cosmos_helper = CosmosDBFacade.from_flask_config(app)


def get_last_day_of_month(year: int, month: int) -> int:
from calendar import monthrange
return monthrange(year=year, month=month)[1]


def get_current_year() -> int:
return datetime.now().year


def get_current_month() -> int:
return datetime.now().month


def get_date_range_of_month(
year: int,
month: int
) -> Dict[str, str]:
first_day_of_month = 1
start_date = datetime(year=year, month=month, day=first_day_of_month)

last_day_of_month = get_last_day_of_month(year=year, month=month)
end_date = datetime(
year=year,
month=month,
day=last_day_of_month,
hour=23,
minute=59,
second=59,
microsecond=999999
)

return {
'start_date': datetime_str(start_date),
'end_date': datetime_str(end_date)
}
2 changes: 1 addition & 1 deletion tests/commons/data_access_layer/cosmos_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_repository_create_sql_where_conditions_with_no_values(cosmos_db_reposit


def test_repository_append_conditions_values(cosmos_db_repository: CosmosDBRepository):
result = cosmos_db_repository.generate_condition_values({'owner_id': 'mark', 'customer_id': 'ioet'})
result = cosmos_db_repository.generate_params({'owner_id': 'mark', 'customer_id': 'ioet'})

assert result is not None
assert result == [{'name': '@owner_id', 'value': 'mark'},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from flask_restplus._http import HTTPStatus
from pytest_mock import MockFixture, pytest

from commons.data_access_layer.cosmos_db import current_datetime, current_datetime_str
from commons.data_access_layer.cosmos_db import current_datetime, \
current_datetime_str, get_date_range_of_month, get_current_month, \
get_current_year, datetime_str
from commons.data_access_layer.database import EventContext
from time_tracker_api.time_entries.time_entries_model import TimeEntriesCosmosDBDao

fake = Faker()
Expand Down Expand Up @@ -434,3 +437,37 @@ def test_create_with_valid_uuid_format_should_return_created(client: FlaskClient

assert HTTPStatus.CREATED == response.status_code
repository_container_create_item_mock.assert_called()


@pytest.mark.parametrize(
'url,month,year',
[
('/time-entries?month=4&year=2020', 4, 2020),
('/time-entries?month=4', 4, get_current_year()),
('/time-entries', get_current_month(), get_current_year())
]
)
def test_find_all_is_called_with_generated_dates(client: FlaskClient,
mocker: MockFixture,
valid_header: dict,
owner_id: str,
url: str,
month: int,
year: int):
from time_tracker_api.time_entries.time_entries_namespace import time_entries_dao
repository_find_all_mock = mocker.patch.object(time_entries_dao.repository,
'find_all',
return_value=fake_time_entry)

response = client.get(url, headers=valid_header, follow_redirects=True)

date_range = get_date_range_of_month(year, month)
conditions = {
'owner_id': owner_id
}

assert HTTPStatus.OK == response.status_code
assert json.loads(response.data) is not None
repository_find_all_mock.assert_called_once_with(ANY,
conditions=conditions,
date_range=date_range)
54 changes: 50 additions & 4 deletions time_tracker_api/time_entries/time_entries_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flask_restplus._http import HTTPStatus

from commons.data_access_layer.cosmos_db import CosmosDBDao, CosmosDBRepository, CustomError, current_datetime_str, \
CosmosDBModel
CosmosDBModel, get_date_range_of_month, get_current_year, get_current_month
from commons.data_access_layer.database import EventContext
from time_tracker_api.database import CRUDDao, APICosmosDBDao
from time_tracker_api.security import current_user_id
Expand Down Expand Up @@ -83,6 +83,32 @@ def create_sql_ignore_id_condition(id: str):
else:
return "AND c.id!=@ignore_id"

@staticmethod
def create_sql_date_range_filter(date_range: dict) -> str:
if 'start_date' and 'end_date' in date_range:
return """
((c.start_date BETWEEN @start_date AND @end_date) OR
(c.end_date BETWEEN @start_date AND @end_date))
"""
else:
return ''

def find_all(self, event_context: EventContext, conditions: dict, date_range: dict):
custom_sql_conditions = []
custom_sql_conditions.append(
self.create_sql_date_range_filter(date_range)
)

custom_params = self.generate_params(date_range)

return CosmosDBRepository.find_all(
self,
event_context=event_context,
conditions=conditions,
custom_sql_conditions=custom_sql_conditions,
custom_params=custom_params
)

def on_create(self, new_item_data: dict, event_context: EventContext):
CosmosDBRepository.on_create(self, new_item_data, event_context)

Expand All @@ -107,7 +133,7 @@ def find_interception_with_date_range(self, start_date, end_date, owner_id, tena
{"name": "@end_date", "value": end_date or current_datetime_str()},
{"name": "@ignore_id", "value": ignore_id},
]
params.extend(self.generate_condition_values(conditions))
params.extend(self.generate_params(conditions))
result = self.container.query_items(
query="""
SELECT * FROM c WHERE ((c.start_date BETWEEN @start_date AND @end_date)
Expand Down Expand Up @@ -138,7 +164,7 @@ def find_running(self, tenant_id: str, owner_id: str, mapper: Callable = None):
visibility_condition=self.create_sql_condition_for_visibility(True),
conditions_clause=self.create_sql_where_conditions(conditions),
),
parameters=self.generate_condition_values(conditions),
parameters=self.generate_params(conditions),
partition_key=tenant_id,
max_item_count=1)

Expand Down Expand Up @@ -194,7 +220,11 @@ def checks_owner_and_is_not_started(cls, data: dict):
def get_all(self, conditions: dict = {}) -> list:
event_ctx = self.create_event_context("read-many")
conditions.update({"owner_id": event_ctx.user_id})
return self.repository.find_all(event_ctx, conditions=conditions)

date_range = self.handle_date_filter_args(args=conditions)
return self.repository.find_all(event_ctx,
conditions=conditions,
date_range=date_range)

def get(self, id):
event_ctx = self.create_event_context("read")
Expand Down Expand Up @@ -229,6 +259,22 @@ def find_running(self):
event_ctx = self.create_event_context("find_running")
return self.repository.find_running(event_ctx.tenant_id, event_ctx.user_id)

@staticmethod
def handle_date_filter_args(args: dict) -> dict:
if 'month' and 'year' in args:
month = int(args.get("month"))
year = int(args.get("year"))
args.pop('month')
args.pop('year')
elif 'month' in args:
month = int(args.get("month"))
year = get_current_year()
args.pop('month')
else:
month = get_current_month()
year = get_current_year()
return get_date_range_of_month(year, month)


def create_dao() -> TimeEntriesDao:
repository = TimeEntryCosmosDBRepository()
Expand Down
10 changes: 10 additions & 0 deletions time_tracker_api/time_entries/time_entries_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,20 @@
"uri",
])

# custom attributes filter
attributes_filter.add_argument('month', required=False,
store_missing=False,
help="(Filter) Month to filter by",
location='args')
attributes_filter.add_argument('year', required=False,
store_missing=False,
help="(Filter) Year to filter by",
location='args')

@ns.route('')
class TimeEntries(Resource):
@ns.doc('list_time_entries')
@ns.expect(attributes_filter)
@ns.marshal_list_with(time_entry)
def get(self):
"""List all time entries"""
Expand Down