diff --git a/commons/data_access_layer/cosmos_db.py b/commons/data_access_layer/cosmos_db.py index cf6aa337..8ac412c6 100644 --- a/commons/data_access_layer/cosmos_db.py +++ b/commons/data_access_layer/cosmos_db.py @@ -222,17 +222,8 @@ def find( function_mapper = self.get_mapper_or_dict(mapper) return function_mapper(self.check_visibility(found_item, visible_only)) - def find_all( - self, - event_context: EventContext, - conditions: dict = {}, - custom_sql_conditions: List[str] = [], - custom_params: dict = {}, - max_count=None, - offset=0, - visible_only=True, - mapper: Callable = None, - ): + def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sql_conditions: List[str] = [], + custom_params: dict = {}, max_count=None, offset=0, visible_only=True, mapper: Callable = None): partition_key_value = self.find_partition_key_value(event_context) max_count = self.get_page_size_or(max_count) params = [ @@ -242,8 +233,7 @@ def find_all( ] params.extend(self.generate_params(conditions)) params.extend(custom_params) - result = self.container.query_items( - query=""" + query_str = """ SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value {conditions_clause} @@ -261,7 +251,10 @@ def find_all( custom_sql_conditions ), order_clause=self.create_sql_order_clause(), - ), + ) + + result = self.container.query_items( + query=query_str, parameters=params, partition_key=partition_key_value, max_item_count=max_count, diff --git a/commons/data_access_layer/database.py b/commons/data_access_layer/database.py index c1497b82..aa895339 100644 --- a/commons/data_access_layer/database.py +++ b/commons/data_access_layer/database.py @@ -33,7 +33,7 @@ def delete(self, id): raise NotImplementedError # pragma: no cover -class EventContext(): +class EventContext: def __init__(self, container_id: str, action: str, description: str = None, user_id: str = None, tenant_id: str = None, session_id: str = None, app_id: str = None): diff --git a/time_tracker_api/database.py b/time_tracker_api/database.py index 5da1ab18..35a58d50 100644 --- a/time_tracker_api/database.py +++ b/time_tracker_api/database.py @@ -11,7 +11,7 @@ from commons.data_access_layer.cosmos_db import CosmosDBDao from commons.data_access_layer.database import EventContext -from time_tracker_api.security import current_user_id, current_user_tenant_id +from time_tracker_api.security import current_user_id, current_user_tenant_id, current_role_user, roles class CRUDDao(abc.ABC): @@ -37,10 +37,11 @@ def delete(self, id): class ApiEventContext(EventContext): - def __init__(self, container_id: str, action: str, description: str = None, - user_id: str = None, tenant_id: str = None, session_id: str = None): + def __init__(self, container_id: str, action: str, description: str = None, user_id: str = None, + tenant_id: str = None, session_id: str = None, user_role: str = None): super(ApiEventContext, self).__init__(container_id, action, description) self._user_id = user_id + self._user_role = user_role self._tenant_id = tenant_id self._session_id = session_id @@ -50,6 +51,10 @@ def user_id(self) -> str: self._user_id = current_user_id() return self._user_id + @property + def user_role(self) -> str: + return self._user_role if self._user_role else current_role_user() + @property def tenant_id(self) -> str: if self._tenant_id is None: @@ -60,6 +65,10 @@ def tenant_id(self) -> str: def session_id(self) -> str: return self._session_id + @property + def is_admin(self): + return True if self.user_role == roles.get("admin").get("name") else False + class APICosmosDBDao(CosmosDBDao): def create_event_context(self, action: str = None, description: str = None): diff --git a/time_tracker_api/projects/projects_model.py b/time_tracker_api/projects/projects_model.py index 447e5c1a..3c4a3aac 100644 --- a/time_tracker_api/projects/projects_model.py +++ b/time_tracker_api/projects/projects_model.py @@ -73,7 +73,11 @@ def get_all(self, conditions: dict = None, **kwargs) -> list: customers_id = [customer.id for customer in customers] conditions = conditions if conditions else {} custom_condition = "c.customer_id IN {}".format(str(tuple(customers_id))) - return self.repository.find_all(event_ctx, conditions, custom_sql_conditions=[custom_condition], **kwargs) + if "custom_sql_conditions" in kwargs: + kwargs["custom_sql_conditions"].append(custom_condition) + else: + kwargs["custom_sql_conditions"] = [custom_condition] + return self.repository.find_all(event_ctx, conditions, **kwargs) def create_dao() -> ProjectDao: diff --git a/time_tracker_api/security.py b/time_tracker_api/security.py index b5919a1d..232c0ff2 100644 --- a/time_tracker_api/security.py +++ b/time_tracker_api/security.py @@ -29,6 +29,11 @@ iss_claim_pattern = re.compile(r"(.*).b2clogin.com/(?P%s)" % UUID_REGEX) +roles = { + "admin": {"name": "time-tracker-admin"}, + "client": {"name": "client-role"} +} + def current_user_id() -> str: oid_claim = get_token_json().get("oid") @@ -38,6 +43,11 @@ def current_user_id() -> str: return oid_claim +def current_role_user() -> str: + role_user = get_token_json().get("extension_role", None) + return role_user if role_user else roles.get("client").get("name") + + def current_user_tenant_id() -> str: iss_claim = get_token_json().get("iss") if iss_claim is None: diff --git a/time_tracker_api/time_entries/time_entries_model.py b/time_tracker_api/time_entries/time_entries_model.py index ec73281a..e2bcdc19 100644 --- a/time_tracker_api/time_entries/time_entries_model.py +++ b/time_tracker_api/time_entries/time_entries_model.py @@ -21,6 +21,7 @@ from time_tracker_api.time_entries.custom_modules.utils import ( add_project_name_to_time_entries, ) +from time_tracker_api.projects.projects_model import ProjectCosmosDBModel, create_dao as project_create_dao from time_tracker_api.projects import projects_model from time_tracker_api.database import CRUDDao, APICosmosDBDao from time_tracker_api.security import current_user_id @@ -74,6 +75,14 @@ def __init__(self, data): # pragma: no cover def running(self): return self.end_date is None + def __add__(self, other): + if type(other) is ProjectCosmosDBModel: + time_entry = self.__class__ + time_entry.project_id = other.__dict__ + return time_entry + else: + raise NotImplementedError + def __repr__(self): return '