diff --git a/commons/data_access_layer/cosmos_db.py b/commons/data_access_layer/cosmos_db.py index 8ca6d8fc..8082bd9a 100644 --- a/commons/data_access_layer/cosmos_db.py +++ b/commons/data_access_layer/cosmos_db.py @@ -26,24 +26,35 @@ def __init__(self, client, db_id: str, logger=None): # pragma: no cover def from_flask_config(cls, app: Flask): db_uri = app.config.get('COSMOS_DATABASE_URI') if db_uri is None: - app.logger.warn("COSMOS_DATABASE_URI was not found. Looking for alternative variables.") + app.logger.warn( + "COSMOS_DATABASE_URI was not found. Looking for alternative variables." + ) account_uri = app.config.get('DATABASE_ACCOUNT_URI') if account_uri is None: - raise EnvironmentError("DATABASE_ACCOUNT_URI is not defined in the environment") + raise EnvironmentError( + "DATABASE_ACCOUNT_URI is not defined in the environment" + ) master_key = app.config.get('DATABASE_MASTER_KEY') if master_key is None: - raise EnvironmentError("DATABASE_MASTER_KEY is not defined in the environment") - - client = cosmos_client.CosmosClient(account_uri, {'masterKey': master_key}, - user_agent="TimeTrackerAPI", - user_agent_overwrite=True) + raise EnvironmentError( + "DATABASE_MASTER_KEY is not defined in the environment" + ) + + client = cosmos_client.CosmosClient( + account_uri, + {'masterKey': master_key}, + user_agent="TimeTrackerAPI", + user_agent_overwrite=True, + ) else: client = cosmos_client.CosmosClient.from_connection_string(db_uri) db_id = app.config.get('DATABASE_NAME') if db_id is None: - raise EnvironmentError("DATABASE_NAME is not defined in the environment") + raise EnvironmentError( + "DATABASE_NAME is not defined in the environment" + ) return cls(client, db_id, logger=app.logger) @@ -60,7 +71,7 @@ def delete_container(self, container_id: str): cosmos_helper: CosmosDBFacade = None -class CosmosDBModel(): +class CosmosDBModel: def __init__(self, data): names = set([f.name for f in dataclasses.fields(self)]) for k, v in data.items(): @@ -73,45 +84,63 @@ def partition_key_attribute(pk: PartitionKey) -> str: class CosmosDBRepository: - def __init__(self, container_id: str, - partition_key_attribute: str, - mapper: Callable = None, - order_fields: list = [], - custom_cosmos_helper: CosmosDBFacade = None): + def __init__( + self, + container_id: str, + partition_key_attribute: str, + mapper: Callable = None, + order_fields: list = [], + custom_cosmos_helper: CosmosDBFacade = None, + ): global cosmos_helper self.cosmos_helper = custom_cosmos_helper or cosmos_helper if self.cosmos_helper is None: # pragma: no cover raise ValueError("The cosmos_db module has not been initialized!") self.mapper = mapper self.order_fields = order_fields - self.container: ContainerProxy = self.cosmos_helper.db.get_container_client(container_id) - self.partition_key_attribute: str = partition_key_attribute + self.container: ContainerProxy = self.cosmos_helper.db.get_container_client( + container_id + ) + self.partition_key_attribute = partition_key_attribute @classmethod - def from_definition(cls, container_definition: dict, - mapper: Callable = None, - custom_cosmos_helper: CosmosDBFacade = None): - pk_attrib = partition_key_attribute(container_definition['partition_key']) - return cls(container_definition['id'], pk_attrib, - mapper=mapper, - custom_cosmos_helper=custom_cosmos_helper) + def from_definition( + cls, + container_definition: dict, + mapper: Callable = None, + custom_cosmos_helper: CosmosDBFacade = None, + ): + pk_attrib = partition_key_attribute( + container_definition['partition_key'] + ) + return cls( + container_definition['id'], + pk_attrib, + mapper=mapper, + custom_cosmos_helper=custom_cosmos_helper, + ) @staticmethod - def create_sql_condition_for_visibility(visible_only: bool, container_name='c') -> str: + def create_sql_condition_for_visibility( + visible_only: bool, container_name='c' + ) -> str: if visible_only: # We are considering that `deleted == null` is not a choice return 'AND NOT IS_DEFINED(%s.deleted)' % container_name return '' @staticmethod - def create_sql_where_conditions(conditions: dict, container_name='c') -> str: + def create_sql_where_conditions( + conditions: dict, container_name='c' + ) -> str: where_conditions = [] for k in conditions.keys(): where_conditions.append(f'{container_name}.{k} = @{k}') if len(where_conditions) > 0: return "AND {where_conditions_clause}".format( - where_conditions_clause=" AND ".join(where_conditions)) + where_conditions_clause=" AND ".join(where_conditions) + ) else: return "" @@ -119,7 +148,10 @@ def create_sql_where_conditions(conditions: dict, container_name='c') -> str: def create_custom_sql_conditions(custom_sql_conditions: List[str]) -> str: if len(custom_sql_conditions) > 0: return "AND {custom_sql_conditions_clause}".format( - custom_sql_conditions_clause=" AND ".join(custom_sql_conditions)) + custom_sql_conditions_clause=" AND ".join( + custom_sql_conditions + ) + ) else: return '' @@ -127,18 +159,16 @@ def create_custom_sql_conditions(custom_sql_conditions: List[str]) -> str: def generate_params(conditions: dict) -> dict: result = [] for k, v in conditions.items(): - result.append({ - "name": "@%s" % k, - "value": v - }) + result.append({"name": "@%s" % k, "value": v}) return result @staticmethod def check_visibility(item, throw_not_found_if_deleted): if throw_not_found_if_deleted and item.get('deleted') is not None: - raise exceptions.CosmosResourceNotFoundError(message='Deleted item', - status_code=404) + raise exceptions.CosmosResourceNotFoundError( + message='Deleted item', status_code=404 + ) return item @staticmethod @@ -158,14 +188,32 @@ def attach_context(data: dict, event_context: EventContext): "session_id": event_context.session_id, } - def create(self, data: dict, event_context: EventContext, mapper: Callable = None): + def create( + self, data: dict, event_context: EventContext, mapper: Callable = None + ): self.on_create(data, event_context) function_mapper = self.get_mapper_or_dict(mapper) self.attach_context(data, event_context) return function_mapper(self.container.create_item(body=data)) - def find(self, id: str, event_context: EventContext, peeker: 'function' = None, - visible_only=True, mapper: Callable = None): + def on_create(self, new_item_data: dict, event_context: EventContext): + if new_item_data.get('id') is None: + new_item_data['id'] = generate_uuid4() + + new_item_data[ + self.partition_key_attribute + ] = self.find_partition_key_value(event_context) + + self.replace_empty_value_per_none(new_item_data) + + def find( + self, + id: str, + event_context: EventContext, + peeker: 'function' = None, + visible_only=True, + mapper: Callable = None, + ): partition_key_value = self.find_partition_key_value(event_context) found_item = self.container.read_item(id, partition_key_value) if peeker: @@ -174,8 +222,17 @@ def find(self, id: str, event_context: EventContext, peeker: 'function' = None, function_mapper = self.get_mapper_or_dict(mapper) return function_mapper(self.check_visibility(found_item, visible_only)) - def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sql_conditions: List[str] = [], - custom_params: dict = {}, max_count=None, offset=0, visible_only=True, mapper: Callable = None): + def find_all( + self, + event_context: EventContext, + conditions: dict = {}, + custom_sql_conditions: List[str] = [], + custom_params: dict = {}, + max_count=None, + offset=0, + visible_only=True, + mapper: Callable = None, + ): partition_key_value = self.find_partition_key_value(event_context) max_count = self.get_page_size_or(max_count) params = [ @@ -185,40 +242,80 @@ def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sq ] params.extend(self.generate_params(conditions)) params.extend(custom_params) - result = self.container.query_items(query=""" - SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value - {conditions_clause} {visibility_condition} {custom_sql_conditions_clause} {order_clause} + result = self.container.query_items( + query=""" + SELECT * FROM c + WHERE c.{partition_key_attribute}=@partition_key_value + {conditions_clause} + {visibility_condition} + {custom_sql_conditions_clause} + {order_clause} OFFSET @offset LIMIT @max_count - """.format(partition_key_attribute=self.partition_key_attribute, - visibility_condition=self.create_sql_condition_for_visibility(visible_only), - conditions_clause=self.create_sql_where_conditions(conditions), - custom_sql_conditions_clause=self.create_custom_sql_conditions(custom_sql_conditions), - order_clause=self.create_sql_order_clause()), - parameters=params, - partition_key=partition_key_value, - max_item_count=max_count) + """.format( + partition_key_attribute=self.partition_key_attribute, + visibility_condition=self.create_sql_condition_for_visibility( + visible_only + ), + conditions_clause=self.create_sql_where_conditions(conditions), + custom_sql_conditions_clause=self.create_custom_sql_conditions( + custom_sql_conditions + ), + order_clause=self.create_sql_order_clause(), + ), + parameters=params, + partition_key=partition_key_value, + max_item_count=max_count, + ) function_mapper = self.get_mapper_or_dict(mapper) return list(map(function_mapper, result)) - def partial_update(self, id: str, changes: dict, event_context: EventContext, - peeker: 'function' = None, visible_only=True, mapper: Callable = None): - item_data = self.find(id, event_context, peeker=peeker, visible_only=visible_only, mapper=dict) + def partial_update( + self, + id: str, + changes: dict, + event_context: EventContext, + peeker: 'function' = None, + visible_only=True, + mapper: Callable = None, + ): + item_data = self.find( + id, + event_context, + peeker=peeker, + visible_only=visible_only, + mapper=dict, + ) item_data.update(changes) return self.update(id, item_data, event_context, mapper=mapper) - def update(self, id: str, item_data: dict, event_context: EventContext, - mapper: Callable = None): + def update( + self, + id: str, + item_data: dict, + event_context: EventContext, + mapper: Callable = None, + ): self.on_update(item_data, event_context) function_mapper = self.get_mapper_or_dict(mapper) self.attach_context(item_data, event_context) return function_mapper(self.container.replace_item(id, body=item_data)) - def delete(self, id: str, event_context: EventContext, - peeker: 'function' = None, mapper: Callable = None): - return self.partial_update(id, { - 'deleted': generate_uuid4() - }, event_context, peeker=peeker, visible_only=True, mapper=mapper) + def delete( + self, + id: str, + event_context: EventContext, + peeker: 'function' = None, + mapper: Callable = None, + ): + return self.partial_update( + id, + {'deleted': generate_uuid4()}, + event_context, + peeker=peeker, + visible_only=True, + mapper=mapper, + ) def delete_permanently(self, id: str, event_context: EventContext) -> None: partition_key_value = self.find_partition_key_value(event_context) @@ -235,14 +332,6 @@ def get_page_size_or(self, custom_page_size: int) -> int: # or any other repository for the settings return custom_page_size or 100 - def on_create(self, new_item_data: dict, event_context: EventContext): - if new_item_data.get('id') is None: - new_item_data['id'] = generate_uuid4() - - new_item_data[self.partition_key_attribute] = self.find_partition_key_value(event_context) - - self.replace_empty_value_per_none(new_item_data) - def on_update(self, update_item_data: dict, event_context: EventContext): pass @@ -276,9 +365,12 @@ def delete(self, id): event_ctx = self.create_event_context("delete") self.repository.delete(id, event_ctx) - def create_event_context(self, action: str = None, description: str = None): - return EventContext(self.repository.container.id, action, - description=description) + def create_event_context( + self, action: str = None, description: str = None + ): + return EventContext( + self.repository.container.id, action, description=description + ) class CustomError(HTTPException): @@ -324,10 +416,7 @@ def get_current_month() -> int: return datetime.now().month -def get_date_range_of_month( - year: int, - month: int -) -> Dict[str, str]: +def get_date_range_of_month(year: int, month: int) -> Dict[str, str]: first_day_of_month = 1 start_date = datetime(year=year, month=month, day=first_day_of_month) @@ -339,10 +428,10 @@ def get_date_range_of_month( hour=23, minute=59, second=59, - microsecond=999999 + microsecond=999999, ) return { 'start_date': datetime_str(start_date), - 'end_date': datetime_str(end_date) + 'end_date': datetime_str(end_date), } diff --git a/time_tracker_api/config.py b/time_tracker_api/config.py index e31d9ab2..49a6afdf 100644 --- a/time_tracker_api/config.py +++ b/time_tracker_api/config.py @@ -12,6 +12,7 @@ class Config: RESTPLUS_VALIDATE = True DEBUG = True CORS_ORIGINS = "*" + ERROR_404_HELP = False class DevelopmentConfig(Config): @@ -49,8 +50,13 @@ class ProductionConfig(Config): class AzureConfig(CosmosDB): - SQL_DATABASE_URI = os.environ.get('SQL_DATABASE_URI', os.environ.get('SQLCONNSTR_DATABASE_URI')) - COSMOS_DATABASE_URI = os.environ.get('COSMOS_DATABASE_URI', os.environ.get('CUSTOMCONNSTR_COSMOS_DATABASE_URI')) + SQL_DATABASE_URI = os.environ.get( + 'SQL_DATABASE_URI', os.environ.get('SQLCONNSTR_DATABASE_URI') + ) + COSMOS_DATABASE_URI = os.environ.get( + 'COSMOS_DATABASE_URI', + os.environ.get('CUSTOMCONNSTR_COSMOS_DATABASE_URI'), + ) SQLALCHEMY_DATABASE_URI = SQL_DATABASE_URI diff --git a/time_tracker_api/time_entries/time_entries_model.py b/time_tracker_api/time_entries/time_entries_model.py index 10cd7989..31044f55 100644 --- a/time_tracker_api/time_entries/time_entries_model.py +++ b/time_tracker_api/time_entries/time_entries_model.py @@ -5,8 +5,16 @@ from azure.cosmos import PartitionKey from flask_restplus._http import HTTPStatus -from commons.data_access_layer.cosmos_db import CosmosDBDao, CosmosDBRepository, CustomError, current_datetime_str, \ - CosmosDBModel, get_date_range_of_month, get_current_year, get_current_month +from commons.data_access_layer.cosmos_db import ( + CosmosDBDao, + CosmosDBRepository, + CustomError, + current_datetime_str, + CosmosDBModel, + get_date_range_of_month, + get_current_year, + get_current_month, +) from commons.data_access_layer.database import EventContext from time_tracker_api.database import CRUDDao, APICosmosDBDao from time_tracker_api.security import current_user_id @@ -34,10 +42,8 @@ def restart(self, id: str): 'id': 'time_entry', 'partition_key': PartitionKey(path='/tenant_id'), 'unique_key_policy': { - 'uniqueKeys': [ - {'paths': ['/owner_id', '/end_date', '/deleted']}, - ] - } + 'uniqueKeys': [{'paths': ['/owner_id', '/end_date', '/deleted']}] + }, } @@ -66,15 +72,20 @@ def __repr__(self): return '