Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 165 additions & 76 deletions commons/data_access_layer/cosmos_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,35 @@ def __init__(self, client, db_id: str, logger=None): # pragma: no cover
def from_flask_config(cls, app: Flask):
db_uri = app.config.get('COSMOS_DATABASE_URI')
if db_uri is None:
app.logger.warn("COSMOS_DATABASE_URI was not found. Looking for alternative variables.")
app.logger.warn(
"COSMOS_DATABASE_URI was not found. Looking for alternative variables."
)
account_uri = app.config.get('DATABASE_ACCOUNT_URI')
if account_uri is None:
raise EnvironmentError("DATABASE_ACCOUNT_URI is not defined in the environment")
raise EnvironmentError(
"DATABASE_ACCOUNT_URI is not defined in the environment"
)

master_key = app.config.get('DATABASE_MASTER_KEY')
if master_key is None:
raise EnvironmentError("DATABASE_MASTER_KEY is not defined in the environment")

client = cosmos_client.CosmosClient(account_uri, {'masterKey': master_key},
user_agent="TimeTrackerAPI",
user_agent_overwrite=True)
raise EnvironmentError(
"DATABASE_MASTER_KEY is not defined in the environment"
)

client = cosmos_client.CosmosClient(
account_uri,
{'masterKey': master_key},
user_agent="TimeTrackerAPI",
user_agent_overwrite=True,
)
else:
client = cosmos_client.CosmosClient.from_connection_string(db_uri)

db_id = app.config.get('DATABASE_NAME')
if db_id is None:
raise EnvironmentError("DATABASE_NAME is not defined in the environment")
raise EnvironmentError(
"DATABASE_NAME is not defined in the environment"
)

return cls(client, db_id, logger=app.logger)

Expand All @@ -60,7 +71,7 @@ def delete_container(self, container_id: str):
cosmos_helper: CosmosDBFacade = None


class CosmosDBModel():
class CosmosDBModel:
def __init__(self, data):
names = set([f.name for f in dataclasses.fields(self)])
for k, v in data.items():
Expand All @@ -73,72 +84,91 @@ def partition_key_attribute(pk: PartitionKey) -> str:


class CosmosDBRepository:
def __init__(self, container_id: str,
partition_key_attribute: str,
mapper: Callable = None,
order_fields: list = [],
custom_cosmos_helper: CosmosDBFacade = None):
def __init__(
self,
container_id: str,
partition_key_attribute: str,
mapper: Callable = None,
order_fields: list = [],
custom_cosmos_helper: CosmosDBFacade = None,
):
global cosmos_helper
self.cosmos_helper = custom_cosmos_helper or cosmos_helper
if self.cosmos_helper is None: # pragma: no cover
raise ValueError("The cosmos_db module has not been initialized!")
self.mapper = mapper
self.order_fields = order_fields
self.container: ContainerProxy = self.cosmos_helper.db.get_container_client(container_id)
self.partition_key_attribute: str = partition_key_attribute
self.container: ContainerProxy = self.cosmos_helper.db.get_container_client(
container_id
)
self.partition_key_attribute = partition_key_attribute

@classmethod
def from_definition(cls, container_definition: dict,
mapper: Callable = None,
custom_cosmos_helper: CosmosDBFacade = None):
pk_attrib = partition_key_attribute(container_definition['partition_key'])
return cls(container_definition['id'], pk_attrib,
mapper=mapper,
custom_cosmos_helper=custom_cosmos_helper)
def from_definition(
cls,
container_definition: dict,
mapper: Callable = None,
custom_cosmos_helper: CosmosDBFacade = None,
):
pk_attrib = partition_key_attribute(
container_definition['partition_key']
)
return cls(
container_definition['id'],
pk_attrib,
mapper=mapper,
custom_cosmos_helper=custom_cosmos_helper,
)

@staticmethod
def create_sql_condition_for_visibility(visible_only: bool, container_name='c') -> str:
def create_sql_condition_for_visibility(
visible_only: bool, container_name='c'
) -> str:
if visible_only:
# We are considering that `deleted == null` is not a choice
return 'AND NOT IS_DEFINED(%s.deleted)' % container_name
return ''

@staticmethod
def create_sql_where_conditions(conditions: dict, container_name='c') -> str:
def create_sql_where_conditions(
conditions: dict, container_name='c'
) -> str:
where_conditions = []
for k in conditions.keys():
where_conditions.append(f'{container_name}.{k} = @{k}')

if len(where_conditions) > 0:
return "AND {where_conditions_clause}".format(
where_conditions_clause=" AND ".join(where_conditions))
where_conditions_clause=" AND ".join(where_conditions)
)
else:
return ""

@staticmethod
def create_custom_sql_conditions(custom_sql_conditions: List[str]) -> str:
if len(custom_sql_conditions) > 0:
return "AND {custom_sql_conditions_clause}".format(
custom_sql_conditions_clause=" AND ".join(custom_sql_conditions))
custom_sql_conditions_clause=" AND ".join(
custom_sql_conditions
)
)
else:
return ''

@staticmethod
def generate_params(conditions: dict) -> dict:
result = []
for k, v in conditions.items():
result.append({
"name": "@%s" % k,
"value": v
})
result.append({"name": "@%s" % k, "value": v})

return result

@staticmethod
def check_visibility(item, throw_not_found_if_deleted):
if throw_not_found_if_deleted and item.get('deleted') is not None:
raise exceptions.CosmosResourceNotFoundError(message='Deleted item',
status_code=404)
raise exceptions.CosmosResourceNotFoundError(
message='Deleted item', status_code=404
)
return item

@staticmethod
Expand All @@ -158,14 +188,32 @@ def attach_context(data: dict, event_context: EventContext):
"session_id": event_context.session_id,
}

def create(self, data: dict, event_context: EventContext, mapper: Callable = None):
def create(
self, data: dict, event_context: EventContext, mapper: Callable = None
):
self.on_create(data, event_context)
function_mapper = self.get_mapper_or_dict(mapper)
self.attach_context(data, event_context)
return function_mapper(self.container.create_item(body=data))

def find(self, id: str, event_context: EventContext, peeker: 'function' = None,
visible_only=True, mapper: Callable = None):
def on_create(self, new_item_data: dict, event_context: EventContext):
if new_item_data.get('id') is None:
new_item_data['id'] = generate_uuid4()

new_item_data[
self.partition_key_attribute
] = self.find_partition_key_value(event_context)

self.replace_empty_value_per_none(new_item_data)

def find(
self,
id: str,
event_context: EventContext,
peeker: 'function' = None,
visible_only=True,
mapper: Callable = None,
):
partition_key_value = self.find_partition_key_value(event_context)
found_item = self.container.read_item(id, partition_key_value)
if peeker:
Expand All @@ -174,8 +222,17 @@ def find(self, id: str, event_context: EventContext, peeker: 'function' = None,
function_mapper = self.get_mapper_or_dict(mapper)
return function_mapper(self.check_visibility(found_item, visible_only))

def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sql_conditions: List[str] = [],
custom_params: dict = {}, max_count=None, offset=0, visible_only=True, mapper: Callable = None):
def find_all(
self,
event_context: EventContext,
conditions: dict = {},
custom_sql_conditions: List[str] = [],
custom_params: dict = {},
max_count=None,
offset=0,
visible_only=True,
mapper: Callable = None,
):
partition_key_value = self.find_partition_key_value(event_context)
max_count = self.get_page_size_or(max_count)
params = [
Expand All @@ -185,40 +242,80 @@ def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sq
]
params.extend(self.generate_params(conditions))
params.extend(custom_params)
result = self.container.query_items(query="""
SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value
{conditions_clause} {visibility_condition} {custom_sql_conditions_clause} {order_clause}
result = self.container.query_items(
query="""
SELECT * FROM c
WHERE c.{partition_key_attribute}=@partition_key_value
{conditions_clause}
{visibility_condition}
{custom_sql_conditions_clause}
{order_clause}
OFFSET @offset LIMIT @max_count
""".format(partition_key_attribute=self.partition_key_attribute,
visibility_condition=self.create_sql_condition_for_visibility(visible_only),
conditions_clause=self.create_sql_where_conditions(conditions),
custom_sql_conditions_clause=self.create_custom_sql_conditions(custom_sql_conditions),
order_clause=self.create_sql_order_clause()),
parameters=params,
partition_key=partition_key_value,
max_item_count=max_count)
""".format(
partition_key_attribute=self.partition_key_attribute,
visibility_condition=self.create_sql_condition_for_visibility(
visible_only
),
conditions_clause=self.create_sql_where_conditions(conditions),
custom_sql_conditions_clause=self.create_custom_sql_conditions(
custom_sql_conditions
),
order_clause=self.create_sql_order_clause(),
),
parameters=params,
partition_key=partition_key_value,
max_item_count=max_count,
)

function_mapper = self.get_mapper_or_dict(mapper)
return list(map(function_mapper, result))

def partial_update(self, id: str, changes: dict, event_context: EventContext,
peeker: 'function' = None, visible_only=True, mapper: Callable = None):
item_data = self.find(id, event_context, peeker=peeker, visible_only=visible_only, mapper=dict)
def partial_update(
self,
id: str,
changes: dict,
event_context: EventContext,
peeker: 'function' = None,
visible_only=True,
mapper: Callable = None,
):
item_data = self.find(
id,
event_context,
peeker=peeker,
visible_only=visible_only,
mapper=dict,
)
item_data.update(changes)
return self.update(id, item_data, event_context, mapper=mapper)

def update(self, id: str, item_data: dict, event_context: EventContext,
mapper: Callable = None):
def update(
self,
id: str,
item_data: dict,
event_context: EventContext,
mapper: Callable = None,
):
self.on_update(item_data, event_context)
function_mapper = self.get_mapper_or_dict(mapper)
self.attach_context(item_data, event_context)
return function_mapper(self.container.replace_item(id, body=item_data))

def delete(self, id: str, event_context: EventContext,
peeker: 'function' = None, mapper: Callable = None):
return self.partial_update(id, {
'deleted': generate_uuid4()
}, event_context, peeker=peeker, visible_only=True, mapper=mapper)
def delete(
self,
id: str,
event_context: EventContext,
peeker: 'function' = None,
mapper: Callable = None,
):
return self.partial_update(
id,
{'deleted': generate_uuid4()},
event_context,
peeker=peeker,
visible_only=True,
mapper=mapper,
)

def delete_permanently(self, id: str, event_context: EventContext) -> None:
partition_key_value = self.find_partition_key_value(event_context)
Expand All @@ -235,14 +332,6 @@ def get_page_size_or(self, custom_page_size: int) -> int:
# or any other repository for the settings
return custom_page_size or 100

def on_create(self, new_item_data: dict, event_context: EventContext):
if new_item_data.get('id') is None:
new_item_data['id'] = generate_uuid4()

new_item_data[self.partition_key_attribute] = self.find_partition_key_value(event_context)

self.replace_empty_value_per_none(new_item_data)

def on_update(self, update_item_data: dict, event_context: EventContext):
pass

Expand Down Expand Up @@ -276,9 +365,12 @@ def delete(self, id):
event_ctx = self.create_event_context("delete")
self.repository.delete(id, event_ctx)

def create_event_context(self, action: str = None, description: str = None):
return EventContext(self.repository.container.id, action,
description=description)
def create_event_context(
self, action: str = None, description: str = None
):
return EventContext(
self.repository.container.id, action, description=description
)


class CustomError(HTTPException):
Expand Down Expand Up @@ -324,10 +416,7 @@ def get_current_month() -> int:
return datetime.now().month


def get_date_range_of_month(
year: int,
month: int
) -> Dict[str, str]:
def get_date_range_of_month(year: int, month: int) -> Dict[str, str]:
first_day_of_month = 1
start_date = datetime(year=year, month=month, day=first_day_of_month)

Expand All @@ -339,10 +428,10 @@ def get_date_range_of_month(
hour=23,
minute=59,
second=59,
microsecond=999999
microsecond=999999,
)

return {
'start_date': datetime_str(start_date),
'end_date': datetime_str(end_date)
'end_date': datetime_str(end_date),
}
Loading