@@ -26,24 +26,35 @@ def __init__(self, client, db_id: str, logger=None): # pragma: no cover
2626 def from_flask_config (cls , app : Flask ):
2727 db_uri = app .config .get ('COSMOS_DATABASE_URI' )
2828 if db_uri is None :
29- app .logger .warn ("COSMOS_DATABASE_URI was not found. Looking for alternative variables." )
29+ app .logger .warn (
30+ "COSMOS_DATABASE_URI was not found. Looking for alternative variables."
31+ )
3032 account_uri = app .config .get ('DATABASE_ACCOUNT_URI' )
3133 if account_uri is None :
32- raise EnvironmentError ("DATABASE_ACCOUNT_URI is not defined in the environment" )
34+ raise EnvironmentError (
35+ "DATABASE_ACCOUNT_URI is not defined in the environment"
36+ )
3337
3438 master_key = app .config .get ('DATABASE_MASTER_KEY' )
3539 if master_key is None :
36- raise EnvironmentError ("DATABASE_MASTER_KEY is not defined in the environment" )
37-
38- client = cosmos_client .CosmosClient (account_uri , {'masterKey' : master_key },
39- user_agent = "TimeTrackerAPI" ,
40- user_agent_overwrite = True )
40+ raise EnvironmentError (
41+ "DATABASE_MASTER_KEY is not defined in the environment"
42+ )
43+
44+ client = cosmos_client .CosmosClient (
45+ account_uri ,
46+ {'masterKey' : master_key },
47+ user_agent = "TimeTrackerAPI" ,
48+ user_agent_overwrite = True ,
49+ )
4150 else :
4251 client = cosmos_client .CosmosClient .from_connection_string (db_uri )
4352
4453 db_id = app .config .get ('DATABASE_NAME' )
4554 if db_id is None :
46- raise EnvironmentError ("DATABASE_NAME is not defined in the environment" )
55+ raise EnvironmentError (
56+ "DATABASE_NAME is not defined in the environment"
57+ )
4758
4859 return cls (client , db_id , logger = app .logger )
4960
@@ -60,7 +71,7 @@ def delete_container(self, container_id: str):
6071cosmos_helper : CosmosDBFacade = None
6172
6273
63- class CosmosDBModel () :
74+ class CosmosDBModel :
6475 def __init__ (self , data ):
6576 names = set ([f .name for f in dataclasses .fields (self )])
6677 for k , v in data .items ():
@@ -73,72 +84,91 @@ def partition_key_attribute(pk: PartitionKey) -> str:
7384
7485
7586class CosmosDBRepository :
76- def __init__ (self , container_id : str ,
77- partition_key_attribute : str ,
78- mapper : Callable = None ,
79- order_fields : list = [],
80- custom_cosmos_helper : CosmosDBFacade = None ):
87+ def __init__ (
88+ self ,
89+ container_id : str ,
90+ partition_key_attribute : str ,
91+ mapper : Callable = None ,
92+ order_fields : list = [],
93+ custom_cosmos_helper : CosmosDBFacade = None ,
94+ ):
8195 global cosmos_helper
8296 self .cosmos_helper = custom_cosmos_helper or cosmos_helper
8397 if self .cosmos_helper is None : # pragma: no cover
8498 raise ValueError ("The cosmos_db module has not been initialized!" )
8599 self .mapper = mapper
86100 self .order_fields = order_fields
87- self .container : ContainerProxy = self .cosmos_helper .db .get_container_client (container_id )
88- self .partition_key_attribute : str = partition_key_attribute
101+ self .container : ContainerProxy = self .cosmos_helper .db .get_container_client (
102+ container_id
103+ )
104+ self .partition_key_attribute = partition_key_attribute
89105
90106 @classmethod
91- def from_definition (cls , container_definition : dict ,
92- mapper : Callable = None ,
93- custom_cosmos_helper : CosmosDBFacade = None ):
94- pk_attrib = partition_key_attribute (container_definition ['partition_key' ])
95- return cls (container_definition ['id' ], pk_attrib ,
96- mapper = mapper ,
97- custom_cosmos_helper = custom_cosmos_helper )
107+ def from_definition (
108+ cls ,
109+ container_definition : dict ,
110+ mapper : Callable = None ,
111+ custom_cosmos_helper : CosmosDBFacade = None ,
112+ ):
113+ pk_attrib = partition_key_attribute (
114+ container_definition ['partition_key' ]
115+ )
116+ return cls (
117+ container_definition ['id' ],
118+ pk_attrib ,
119+ mapper = mapper ,
120+ custom_cosmos_helper = custom_cosmos_helper ,
121+ )
98122
99123 @staticmethod
100- def create_sql_condition_for_visibility (visible_only : bool , container_name = 'c' ) -> str :
124+ def create_sql_condition_for_visibility (
125+ visible_only : bool , container_name = 'c'
126+ ) -> str :
101127 if visible_only :
102128 # We are considering that `deleted == null` is not a choice
103129 return 'AND NOT IS_DEFINED(%s.deleted)' % container_name
104130 return ''
105131
106132 @staticmethod
107- def create_sql_where_conditions (conditions : dict , container_name = 'c' ) -> str :
133+ def create_sql_where_conditions (
134+ conditions : dict , container_name = 'c'
135+ ) -> str :
108136 where_conditions = []
109137 for k in conditions .keys ():
110138 where_conditions .append (f'{ container_name } .{ k } = @{ k } ' )
111139
112140 if len (where_conditions ) > 0 :
113141 return "AND {where_conditions_clause}" .format (
114- where_conditions_clause = " AND " .join (where_conditions ))
142+ where_conditions_clause = " AND " .join (where_conditions )
143+ )
115144 else :
116145 return ""
117146
118147 @staticmethod
119148 def create_custom_sql_conditions (custom_sql_conditions : List [str ]) -> str :
120149 if len (custom_sql_conditions ) > 0 :
121150 return "AND {custom_sql_conditions_clause}" .format (
122- custom_sql_conditions_clause = " AND " .join (custom_sql_conditions ))
151+ custom_sql_conditions_clause = " AND " .join (
152+ custom_sql_conditions
153+ )
154+ )
123155 else :
124156 return ''
125157
126158 @staticmethod
127159 def generate_params (conditions : dict ) -> dict :
128160 result = []
129161 for k , v in conditions .items ():
130- result .append ({
131- "name" : "@%s" % k ,
132- "value" : v
133- })
162+ result .append ({"name" : "@%s" % k , "value" : v })
134163
135164 return result
136165
137166 @staticmethod
138167 def check_visibility (item , throw_not_found_if_deleted ):
139168 if throw_not_found_if_deleted and item .get ('deleted' ) is not None :
140- raise exceptions .CosmosResourceNotFoundError (message = 'Deleted item' ,
141- status_code = 404 )
169+ raise exceptions .CosmosResourceNotFoundError (
170+ message = 'Deleted item' , status_code = 404
171+ )
142172 return item
143173
144174 @staticmethod
@@ -158,14 +188,32 @@ def attach_context(data: dict, event_context: EventContext):
158188 "session_id" : event_context .session_id ,
159189 }
160190
161- def create (self , data : dict , event_context : EventContext , mapper : Callable = None ):
191+ def create (
192+ self , data : dict , event_context : EventContext , mapper : Callable = None
193+ ):
162194 self .on_create (data , event_context )
163195 function_mapper = self .get_mapper_or_dict (mapper )
164196 self .attach_context (data , event_context )
165197 return function_mapper (self .container .create_item (body = data ))
166198
167- def find (self , id : str , event_context : EventContext , peeker : 'function' = None ,
168- visible_only = True , mapper : Callable = None ):
199+ def on_create (self , new_item_data : dict , event_context : EventContext ):
200+ if new_item_data .get ('id' ) is None :
201+ new_item_data ['id' ] = generate_uuid4 ()
202+
203+ new_item_data [
204+ self .partition_key_attribute
205+ ] = self .find_partition_key_value (event_context )
206+
207+ self .replace_empty_value_per_none (new_item_data )
208+
209+ def find (
210+ self ,
211+ id : str ,
212+ event_context : EventContext ,
213+ peeker : 'function' = None ,
214+ visible_only = True ,
215+ mapper : Callable = None ,
216+ ):
169217 partition_key_value = self .find_partition_key_value (event_context )
170218 found_item = self .container .read_item (id , partition_key_value )
171219 if peeker :
@@ -174,8 +222,17 @@ def find(self, id: str, event_context: EventContext, peeker: 'function' = None,
174222 function_mapper = self .get_mapper_or_dict (mapper )
175223 return function_mapper (self .check_visibility (found_item , visible_only ))
176224
177- def find_all (self , event_context : EventContext , conditions : dict = {}, custom_sql_conditions : List [str ] = [],
178- custom_params : dict = {}, max_count = None , offset = 0 , visible_only = True , mapper : Callable = None ):
225+ def find_all (
226+ self ,
227+ event_context : EventContext ,
228+ conditions : dict = {},
229+ custom_sql_conditions : List [str ] = [],
230+ custom_params : dict = {},
231+ max_count = None ,
232+ offset = 0 ,
233+ visible_only = True ,
234+ mapper : Callable = None ,
235+ ):
179236 partition_key_value = self .find_partition_key_value (event_context )
180237 max_count = self .get_page_size_or (max_count )
181238 params = [
@@ -185,40 +242,80 @@ def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sq
185242 ]
186243 params .extend (self .generate_params (conditions ))
187244 params .extend (custom_params )
188- result = self .container .query_items (query = """
189- SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value
190- {conditions_clause} {visibility_condition} {custom_sql_conditions_clause} {order_clause}
245+ result = self .container .query_items (
246+ query = """
247+ SELECT * FROM c
248+ WHERE c.{partition_key_attribute}=@partition_key_value
249+ {conditions_clause}
250+ {visibility_condition}
251+ {custom_sql_conditions_clause}
252+ {order_clause}
191253 OFFSET @offset LIMIT @max_count
192- """ .format (partition_key_attribute = self .partition_key_attribute ,
193- visibility_condition = self .create_sql_condition_for_visibility (visible_only ),
194- conditions_clause = self .create_sql_where_conditions (conditions ),
195- custom_sql_conditions_clause = self .create_custom_sql_conditions (custom_sql_conditions ),
196- order_clause = self .create_sql_order_clause ()),
197- parameters = params ,
198- partition_key = partition_key_value ,
199- max_item_count = max_count )
254+ """ .format (
255+ partition_key_attribute = self .partition_key_attribute ,
256+ visibility_condition = self .create_sql_condition_for_visibility (
257+ visible_only
258+ ),
259+ conditions_clause = self .create_sql_where_conditions (conditions ),
260+ custom_sql_conditions_clause = self .create_custom_sql_conditions (
261+ custom_sql_conditions
262+ ),
263+ order_clause = self .create_sql_order_clause (),
264+ ),
265+ parameters = params ,
266+ partition_key = partition_key_value ,
267+ max_item_count = max_count ,
268+ )
200269
201270 function_mapper = self .get_mapper_or_dict (mapper )
202271 return list (map (function_mapper , result ))
203272
204- def partial_update (self , id : str , changes : dict , event_context : EventContext ,
205- peeker : 'function' = None , visible_only = True , mapper : Callable = None ):
206- item_data = self .find (id , event_context , peeker = peeker , visible_only = visible_only , mapper = dict )
273+ def partial_update (
274+ self ,
275+ id : str ,
276+ changes : dict ,
277+ event_context : EventContext ,
278+ peeker : 'function' = None ,
279+ visible_only = True ,
280+ mapper : Callable = None ,
281+ ):
282+ item_data = self .find (
283+ id ,
284+ event_context ,
285+ peeker = peeker ,
286+ visible_only = visible_only ,
287+ mapper = dict ,
288+ )
207289 item_data .update (changes )
208290 return self .update (id , item_data , event_context , mapper = mapper )
209291
210- def update (self , id : str , item_data : dict , event_context : EventContext ,
211- mapper : Callable = None ):
292+ def update (
293+ self ,
294+ id : str ,
295+ item_data : dict ,
296+ event_context : EventContext ,
297+ mapper : Callable = None ,
298+ ):
212299 self .on_update (item_data , event_context )
213300 function_mapper = self .get_mapper_or_dict (mapper )
214301 self .attach_context (item_data , event_context )
215302 return function_mapper (self .container .replace_item (id , body = item_data ))
216303
217- def delete (self , id : str , event_context : EventContext ,
218- peeker : 'function' = None , mapper : Callable = None ):
219- return self .partial_update (id , {
220- 'deleted' : generate_uuid4 ()
221- }, event_context , peeker = peeker , visible_only = True , mapper = mapper )
304+ def delete (
305+ self ,
306+ id : str ,
307+ event_context : EventContext ,
308+ peeker : 'function' = None ,
309+ mapper : Callable = None ,
310+ ):
311+ return self .partial_update (
312+ id ,
313+ {'deleted' : generate_uuid4 ()},
314+ event_context ,
315+ peeker = peeker ,
316+ visible_only = True ,
317+ mapper = mapper ,
318+ )
222319
223320 def delete_permanently (self , id : str , event_context : EventContext ) -> None :
224321 partition_key_value = self .find_partition_key_value (event_context )
@@ -235,14 +332,6 @@ def get_page_size_or(self, custom_page_size: int) -> int:
235332 # or any other repository for the settings
236333 return custom_page_size or 100
237334
238- def on_create (self , new_item_data : dict , event_context : EventContext ):
239- if new_item_data .get ('id' ) is None :
240- new_item_data ['id' ] = generate_uuid4 ()
241-
242- new_item_data [self .partition_key_attribute ] = self .find_partition_key_value (event_context )
243-
244- self .replace_empty_value_per_none (new_item_data )
245-
246335 def on_update (self , update_item_data : dict , event_context : EventContext ):
247336 pass
248337
@@ -276,9 +365,12 @@ def delete(self, id):
276365 event_ctx = self .create_event_context ("delete" )
277366 self .repository .delete (id , event_ctx )
278367
279- def create_event_context (self , action : str = None , description : str = None ):
280- return EventContext (self .repository .container .id , action ,
281- description = description )
368+ def create_event_context (
369+ self , action : str = None , description : str = None
370+ ):
371+ return EventContext (
372+ self .repository .container .id , action , description = description
373+ )
282374
283375
284376class CustomError (HTTPException ):
@@ -324,10 +416,7 @@ def get_current_month() -> int:
324416 return datetime .now ().month
325417
326418
327- def get_date_range_of_month (
328- year : int ,
329- month : int
330- ) -> Dict [str , str ]:
419+ def get_date_range_of_month (year : int , month : int ) -> Dict [str , str ]:
331420 first_day_of_month = 1
332421 start_date = datetime (year = year , month = month , day = first_day_of_month )
333422
@@ -339,10 +428,10 @@ def get_date_range_of_month(
339428 hour = 23 ,
340429 minute = 59 ,
341430 second = 59 ,
342- microsecond = 999999
431+ microsecond = 999999 ,
343432 )
344433
345434 return {
346435 'start_date' : datetime_str (start_date ),
347- 'end_date' : datetime_str (end_date )
436+ 'end_date' : datetime_str (end_date ),
348437 }
0 commit comments