Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export FLASK_APP=time_tracker_api
## For Azure Cosmos DB
export DATABASE_ACCOUNT_URI=https://<project_db_name>.documents.azure.com:443
export DATABASE_MASTER_KEY=<db_master_key>
export DATABASE_NAME=<db_name>
### or
# export COSMOS_DATABASE_URI=AccountEndpoint=<ACCOUNT_URI>;AccountKey=<ACCOUNT_KEY>
## Also specify the database name
export DATABASE_NAME=<db_name>
87 changes: 50 additions & 37 deletions commons/data_access_layer/cosmos_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from werkzeug.exceptions import HTTPException

from commons.data_access_layer.database import CRUDDao
from time_tracker_api.security import current_user_tenant_id, current_user_id
from time_tracker_api.security import current_user_tenant_id


class CosmosDBFacade:
Expand Down Expand Up @@ -105,10 +105,26 @@ def create_sql_condition_for_visibility(visible_only: bool, container_name='c')
return ''

@staticmethod
def create_sql_condition_for_owner_id(owner_id: str, container_name='c') -> str:
if owner_id:
return 'AND %s.owner_id=@owner_id' % container_name
return ''
def create_sql_where_conditions(conditions: dict, container_name='c') -> str:
where_conditions = []
for k in conditions.keys():
where_conditions.append('{c}.{var} = @{var}'.format(c=container_name, var=k))

if len(where_conditions) > 0:
return "AND {where_conditions_clause}".format(
where_conditions_clause=" AND ".join(where_conditions))
else:
return ""

@staticmethod
def append_conditions_values(params: list, conditions: dict) -> dict:
for k, v in conditions.items():
params.append({
"name": "@%s" % k,
"value": v
})

return params

@staticmethod
def check_visibility(item, throw_not_found_if_deleted):
Expand All @@ -123,39 +139,41 @@ def create(self, data: dict, mapper: Callable = None):
function_mapper = self.get_mapper_or_dict(mapper)
return function_mapper(self.container.create_item(body=data))

def find(self, id: str, partition_key_value, visible_only=True, mapper: Callable = None):
def find(self, id: str, partition_key_value, peeker: 'function' = None, visible_only=True, mapper: Callable = None):
found_item = self.container.read_item(id, partition_key_value)
if peeker:
peeker(found_item)

function_mapper = self.get_mapper_or_dict(mapper)
return function_mapper(self.check_visibility(found_item, visible_only))

def find_all(self, partition_key_value: str, owner_id=None, max_count=None, offset=0,
def find_all(self, partition_key_value: str, conditions: dict = {}, max_count=None, offset=0,
visible_only=True, mapper: Callable = None):
# TODO Use the tenant_id param and change container alias
max_count = self.get_page_size_or(max_count)
result = self.container.query_items(
query="""
params = self.append_conditions_values([
{"name": "@partition_key_value", "value": partition_key_value},
{"name": "@offset", "value": offset},
{"name": "@max_count", "value": max_count},
], conditions)
result = self.container.query_items(query="""
SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value
{owner_condition} {visibility_condition} {order_clause}
{conditions_clause} {visibility_condition} {order_clause}
OFFSET @offset LIMIT @max_count
""".format(partition_key_attribute=self.partition_key_attribute,
visibility_condition=self.create_sql_condition_for_visibility(visible_only),
owner_condition=self.create_sql_condition_for_owner_id(owner_id),
conditions_clause=self.create_sql_where_conditions(conditions),
order_clause=self.create_sql_order_clause()),
parameters=[
{"name": "@partition_key_value", "value": partition_key_value},
{"name": "@offset", "value": offset},
{"name": "@max_count", "value": max_count},
{"name": "@owner_id", "value": owner_id},
],
partition_key=partition_key_value,
max_item_count=max_count)
parameters=params,
partition_key=partition_key_value,
max_item_count=max_count)

function_mapper = self.get_mapper_or_dict(mapper)
return list(map(function_mapper, result))

def partial_update(self, id: str, changes: dict, partition_key_value: str,
visible_only=True, mapper: Callable = None):
item_data = self.find(id, partition_key_value, visible_only=visible_only, mapper=dict)
peeker: 'function' = None, visible_only=True, mapper: Callable = None):
item_data = self.find(id, partition_key_value, peeker=peeker, visible_only=visible_only, mapper=dict)
item_data.update(changes)
return self.update(id, item_data, mapper=mapper)

Expand All @@ -164,10 +182,11 @@ def update(self, id: str, item_data: dict, mapper: Callable = None):
function_mapper = self.get_mapper_or_dict(mapper)
return function_mapper(self.container.replace_item(id, body=item_data))

def delete(self, id: str, partition_key_value: str, mapper: Callable = None):
def delete(self, id: str, partition_key_value: str,
peeker: 'function' = None, mapper: Callable = None):
return self.partial_update(id, {
'deleted': str(uuid.uuid4())
}, partition_key_value, visible_only=True, mapper=mapper)
}, partition_key_value, peeker=peeker, visible_only=True, mapper=mapper)

def delete_permanently(self, id: str, partition_key_value: str) -> None:
self.container.delete_item(id, partition_key_value)
Expand All @@ -190,30 +209,21 @@ def on_update(self, update_item_data: dict):
def create_sql_order_clause(self):
if len(self.order_fields) > 0:
return "ORDER BY c.{}".format(", c.".join(self.order_fields))
else:
return ""
return ""


class CosmosDBDao(CRUDDao):
def __init__(self, repository: CosmosDBRepository):
self.repository = repository

@property
def partition_key_value(self):
return current_user_tenant_id()

def get_all(self) -> list:
tenant_id: str = self.partition_key_value
owner_id = current_user_id()
return self.repository.find_all(partition_key_value=tenant_id, owner_id=owner_id)
return self.repository.find_all(partition_key_value=self.partition_key_value)

def get(self, id):
tenant_id: str = self.partition_key_value
return self.repository.find(id, partition_key_value=tenant_id)
return self.repository.find(id, partition_key_value=self.partition_key_value)

def create(self, data: dict):
data[self.repository.partition_key_attribute] = self.partition_key_value
data['owner_id'] = current_user_id()
return self.repository.create(data)

def update(self, id, data: dict):
Expand All @@ -222,8 +232,11 @@ def update(self, id, data: dict):
partition_key_value=self.partition_key_value)

def delete(self, id):
tenant_id: str = current_user_tenant_id()
self.repository.delete(id, partition_key_value=tenant_id)
self.repository.delete(id, partition_key_value=self.partition_key_value)

@property
def partition_key_value(self):
return current_user_tenant_id()


class CustomError(HTTPException):
Expand Down
54 changes: 52 additions & 2 deletions tests/commons/data_access_layer/cosmos_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import pytest
from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError
from faker import Faker
from flask_restplus._http import HTTPStatus
from pytest import fail

from commons.data_access_layer.cosmos_db import CosmosDBRepository, CosmosDBModel
from commons.data_access_layer.cosmos_db import CosmosDBRepository, CosmosDBModel, CustomError

fake = Faker()
Faker.seed()
Expand Down Expand Up @@ -459,7 +460,7 @@ def test_find_all_can_find_deleted_items_only_if_visibile_only_is_true(
assert deleted_item is not None
assert deleted_item['deleted'] is not None

visible_items = cosmos_db_repository.find_all(sample_item['tenant_id'])
visible_items = cosmos_db_repository.find_all(partition_key_value=sample_item.get('tenant_id'))

assert visible_items is not None
assert any(item['id'] == sample_item['id'] for item in visible_items) == False, \
Expand Down Expand Up @@ -536,3 +537,52 @@ def test_delete_permanently_with_valid_id_should_succeed(
except Exception as e:
assert type(e) is CosmosResourceNotFoundError
assert e.status_code == 404


def test_repository_create_sql_where_conditions_with_multiple_values(cosmos_db_repository: CosmosDBRepository):
result = cosmos_db_repository.create_sql_where_conditions({
'owner_id': 'mark',
'customer_id': 'me'
}, "c")

assert result is not None
assert result == "AND c.owner_id = @owner_id AND c.customer_id = @customer_id"


def test_repository_create_sql_where_conditions_with_no_values(cosmos_db_repository: CosmosDBRepository):
result = cosmos_db_repository.create_sql_where_conditions({}, "c")

assert result is not None
assert result == ""


def test_repository_append_conditions_values(cosmos_db_repository: CosmosDBRepository):
result = cosmos_db_repository.append_conditions_values([], {'owner_id': 'mark', 'customer_id': 'ioet'})

assert result is not None
assert result == [{'name': '@owner_id', 'value': 'mark'},
{'name': '@customer_id', 'value': 'ioet'}]


def test_find_should_call_picker_if_it_was_specified(cosmos_db_repository: CosmosDBRepository,
sample_item: dict,
another_item: dict):
def raise_bad_request_if_name_diff_the_one_from_sample_item(data: dict):
if sample_item['name'] != data['name']:
raise CustomError(HTTPStatus.BAD_REQUEST, "Anything")

found_item = cosmos_db_repository.find(sample_item['id'],
partition_key_value=sample_item['tenant_id'])

assert found_item is not None
assert found_item['id'] == sample_item['id']

try:
cosmos_db_repository.find(another_item['id'],
partition_key_value=another_item['tenant_id'],
peeker=raise_bad_request_if_name_diff_the_one_from_sample_item)

fail('It should have not found any item because of condition')
except Exception as e:
assert e.code == HTTPStatus.BAD_REQUEST
assert e.description == "Anything"
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def sample_item(cosmos_db_repository: CosmosDBRepository, tenant_id: str) -> dic
return cosmos_db_repository.create(sample_item_data)


@pytest.fixture(scope="function")
def another_item(cosmos_db_repository: CosmosDBRepository, tenant_id: str) -> dict:
sample_item_data = dict(id=fake.uuid4(),
name=fake.name(),
email=fake.safe_email(),
age=fake.pyint(min_value=10, max_value=80),
tenant_id=tenant_id)

return cosmos_db_repository.create(sample_item_data)


@pytest.yield_fixture(scope="module")
def time_entry_repository(cosmos_db_repository: CosmosDBRepository) -> TimeEntryCosmosDBRepository:
def time_entry_repository() -> TimeEntryCosmosDBRepository:
return TimeEntryCosmosDBRepository()
12 changes: 6 additions & 6 deletions tests/time_tracker_api/time_entries/time_entries_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_find_interception_with_date_range_should_find(start_date: datetime,
finally:
time_entry_repository.delete_permanently(existing_item.id, partition_key_value=existing_item.tenant_id)


def test_find_interception_should_ignore_id_of_existing_item(owner_id: str,
tenant_id: str,
time_entry_repository: TimeEntryCosmosDBRepository):
Expand All @@ -62,13 +63,13 @@ def test_find_interception_should_ignore_id_of_existing_item(owner_id: str,
try:

colliding_result = time_entry_repository.find_interception_with_date_range(start_date, end_date,
owner_id=owner_id,
partition_key_value=tenant_id)
owner_id=owner_id,
partition_key_value=tenant_id)

non_colliding_result = time_entry_repository.find_interception_with_date_range(start_date, end_date,
owner_id=owner_id,
partition_key_value=tenant_id,
ignore_id=existing_item.id)
owner_id=owner_id,
partition_key_value=tenant_id,
ignore_id=existing_item.id)

colliding_result is not None
assert any([existing_item.id == item.id for item in colliding_result])
Expand All @@ -77,4 +78,3 @@ def test_find_interception_should_ignore_id_of_existing_item(owner_id: str,
assert not any([existing_item.id == item.id for item in non_colliding_result])
finally:
time_entry_repository.delete_permanently(existing_item.id, partition_key_value=existing_item.tenant_id)

Loading