@@ -26,24 +26,35 @@ def __init__(self, client, db_id: str, logger=None): # pragma: no cover
26
26
def from_flask_config (cls , app : Flask ):
27
27
db_uri = app .config .get ('COSMOS_DATABASE_URI' )
28
28
if db_uri is None :
29
- app .logger .warn ("COSMOS_DATABASE_URI was not found. Looking for alternative variables." )
29
+ app .logger .warn (
30
+ "COSMOS_DATABASE_URI was not found. Looking for alternative variables."
31
+ )
30
32
account_uri = app .config .get ('DATABASE_ACCOUNT_URI' )
31
33
if account_uri is None :
32
- raise EnvironmentError ("DATABASE_ACCOUNT_URI is not defined in the environment" )
34
+ raise EnvironmentError (
35
+ "DATABASE_ACCOUNT_URI is not defined in the environment"
36
+ )
33
37
34
38
master_key = app .config .get ('DATABASE_MASTER_KEY' )
35
39
if master_key is None :
36
- raise EnvironmentError ("DATABASE_MASTER_KEY is not defined in the environment" )
37
-
38
- client = cosmos_client .CosmosClient (account_uri , {'masterKey' : master_key },
39
- user_agent = "TimeTrackerAPI" ,
40
- user_agent_overwrite = True )
40
+ raise EnvironmentError (
41
+ "DATABASE_MASTER_KEY is not defined in the environment"
42
+ )
43
+
44
+ client = cosmos_client .CosmosClient (
45
+ account_uri ,
46
+ {'masterKey' : master_key },
47
+ user_agent = "TimeTrackerAPI" ,
48
+ user_agent_overwrite = True ,
49
+ )
41
50
else :
42
51
client = cosmos_client .CosmosClient .from_connection_string (db_uri )
43
52
44
53
db_id = app .config .get ('DATABASE_NAME' )
45
54
if db_id is None :
46
- raise EnvironmentError ("DATABASE_NAME is not defined in the environment" )
55
+ raise EnvironmentError (
56
+ "DATABASE_NAME is not defined in the environment"
57
+ )
47
58
48
59
return cls (client , db_id , logger = app .logger )
49
60
@@ -60,7 +71,7 @@ def delete_container(self, container_id: str):
60
71
cosmos_helper : CosmosDBFacade = None
61
72
62
73
63
- class CosmosDBModel () :
74
+ class CosmosDBModel :
64
75
def __init__ (self , data ):
65
76
names = set ([f .name for f in dataclasses .fields (self )])
66
77
for k , v in data .items ():
@@ -73,72 +84,91 @@ def partition_key_attribute(pk: PartitionKey) -> str:
73
84
74
85
75
86
class CosmosDBRepository :
76
- def __init__ (self , container_id : str ,
77
- partition_key_attribute : str ,
78
- mapper : Callable = None ,
79
- order_fields : list = [],
80
- custom_cosmos_helper : CosmosDBFacade = None ):
87
+ def __init__ (
88
+ self ,
89
+ container_id : str ,
90
+ partition_key_attribute : str ,
91
+ mapper : Callable = None ,
92
+ order_fields : list = [],
93
+ custom_cosmos_helper : CosmosDBFacade = None ,
94
+ ):
81
95
global cosmos_helper
82
96
self .cosmos_helper = custom_cosmos_helper or cosmos_helper
83
97
if self .cosmos_helper is None : # pragma: no cover
84
98
raise ValueError ("The cosmos_db module has not been initialized!" )
85
99
self .mapper = mapper
86
100
self .order_fields = order_fields
87
- self .container : ContainerProxy = self .cosmos_helper .db .get_container_client (container_id )
88
- self .partition_key_attribute : str = partition_key_attribute
101
+ self .container : ContainerProxy = self .cosmos_helper .db .get_container_client (
102
+ container_id
103
+ )
104
+ self .partition_key_attribute = partition_key_attribute
89
105
90
106
@classmethod
91
- def from_definition (cls , container_definition : dict ,
92
- mapper : Callable = None ,
93
- custom_cosmos_helper : CosmosDBFacade = None ):
94
- pk_attrib = partition_key_attribute (container_definition ['partition_key' ])
95
- return cls (container_definition ['id' ], pk_attrib ,
96
- mapper = mapper ,
97
- custom_cosmos_helper = custom_cosmos_helper )
107
+ def from_definition (
108
+ cls ,
109
+ container_definition : dict ,
110
+ mapper : Callable = None ,
111
+ custom_cosmos_helper : CosmosDBFacade = None ,
112
+ ):
113
+ pk_attrib = partition_key_attribute (
114
+ container_definition ['partition_key' ]
115
+ )
116
+ return cls (
117
+ container_definition ['id' ],
118
+ pk_attrib ,
119
+ mapper = mapper ,
120
+ custom_cosmos_helper = custom_cosmos_helper ,
121
+ )
98
122
99
123
@staticmethod
100
- def create_sql_condition_for_visibility (visible_only : bool , container_name = 'c' ) -> str :
124
+ def create_sql_condition_for_visibility (
125
+ visible_only : bool , container_name = 'c'
126
+ ) -> str :
101
127
if visible_only :
102
128
# We are considering that `deleted == null` is not a choice
103
129
return 'AND NOT IS_DEFINED(%s.deleted)' % container_name
104
130
return ''
105
131
106
132
@staticmethod
107
- def create_sql_where_conditions (conditions : dict , container_name = 'c' ) -> str :
133
+ def create_sql_where_conditions (
134
+ conditions : dict , container_name = 'c'
135
+ ) -> str :
108
136
where_conditions = []
109
137
for k in conditions .keys ():
110
138
where_conditions .append (f'{ container_name } .{ k } = @{ k } ' )
111
139
112
140
if len (where_conditions ) > 0 :
113
141
return "AND {where_conditions_clause}" .format (
114
- where_conditions_clause = " AND " .join (where_conditions ))
142
+ where_conditions_clause = " AND " .join (where_conditions )
143
+ )
115
144
else :
116
145
return ""
117
146
118
147
@staticmethod
119
148
def create_custom_sql_conditions (custom_sql_conditions : List [str ]) -> str :
120
149
if len (custom_sql_conditions ) > 0 :
121
150
return "AND {custom_sql_conditions_clause}" .format (
122
- custom_sql_conditions_clause = " AND " .join (custom_sql_conditions ))
151
+ custom_sql_conditions_clause = " AND " .join (
152
+ custom_sql_conditions
153
+ )
154
+ )
123
155
else :
124
156
return ''
125
157
126
158
@staticmethod
127
159
def generate_params (conditions : dict ) -> dict :
128
160
result = []
129
161
for k , v in conditions .items ():
130
- result .append ({
131
- "name" : "@%s" % k ,
132
- "value" : v
133
- })
162
+ result .append ({"name" : "@%s" % k , "value" : v })
134
163
135
164
return result
136
165
137
166
@staticmethod
138
167
def check_visibility (item , throw_not_found_if_deleted ):
139
168
if throw_not_found_if_deleted and item .get ('deleted' ) is not None :
140
- raise exceptions .CosmosResourceNotFoundError (message = 'Deleted item' ,
141
- status_code = 404 )
169
+ raise exceptions .CosmosResourceNotFoundError (
170
+ message = 'Deleted item' , status_code = 404
171
+ )
142
172
return item
143
173
144
174
@staticmethod
@@ -158,14 +188,32 @@ def attach_context(data: dict, event_context: EventContext):
158
188
"session_id" : event_context .session_id ,
159
189
}
160
190
161
- def create (self , data : dict , event_context : EventContext , mapper : Callable = None ):
191
+ def create (
192
+ self , data : dict , event_context : EventContext , mapper : Callable = None
193
+ ):
162
194
self .on_create (data , event_context )
163
195
function_mapper = self .get_mapper_or_dict (mapper )
164
196
self .attach_context (data , event_context )
165
197
return function_mapper (self .container .create_item (body = data ))
166
198
167
- def find (self , id : str , event_context : EventContext , peeker : 'function' = None ,
168
- visible_only = True , mapper : Callable = None ):
199
+ def on_create (self , new_item_data : dict , event_context : EventContext ):
200
+ if new_item_data .get ('id' ) is None :
201
+ new_item_data ['id' ] = generate_uuid4 ()
202
+
203
+ new_item_data [
204
+ self .partition_key_attribute
205
+ ] = self .find_partition_key_value (event_context )
206
+
207
+ self .replace_empty_value_per_none (new_item_data )
208
+
209
+ def find (
210
+ self ,
211
+ id : str ,
212
+ event_context : EventContext ,
213
+ peeker : 'function' = None ,
214
+ visible_only = True ,
215
+ mapper : Callable = None ,
216
+ ):
169
217
partition_key_value = self .find_partition_key_value (event_context )
170
218
found_item = self .container .read_item (id , partition_key_value )
171
219
if peeker :
@@ -174,8 +222,17 @@ def find(self, id: str, event_context: EventContext, peeker: 'function' = None,
174
222
function_mapper = self .get_mapper_or_dict (mapper )
175
223
return function_mapper (self .check_visibility (found_item , visible_only ))
176
224
177
- def find_all (self , event_context : EventContext , conditions : dict = {}, custom_sql_conditions : List [str ] = [],
178
- custom_params : dict = {}, max_count = None , offset = 0 , visible_only = True , mapper : Callable = None ):
225
+ def find_all (
226
+ self ,
227
+ event_context : EventContext ,
228
+ conditions : dict = {},
229
+ custom_sql_conditions : List [str ] = [],
230
+ custom_params : dict = {},
231
+ max_count = None ,
232
+ offset = 0 ,
233
+ visible_only = True ,
234
+ mapper : Callable = None ,
235
+ ):
179
236
partition_key_value = self .find_partition_key_value (event_context )
180
237
max_count = self .get_page_size_or (max_count )
181
238
params = [
@@ -185,40 +242,80 @@ def find_all(self, event_context: EventContext, conditions: dict = {}, custom_sq
185
242
]
186
243
params .extend (self .generate_params (conditions ))
187
244
params .extend (custom_params )
188
- result = self .container .query_items (query = """
189
- SELECT * FROM c WHERE c.{partition_key_attribute}=@partition_key_value
190
- {conditions_clause} {visibility_condition} {custom_sql_conditions_clause} {order_clause}
245
+ result = self .container .query_items (
246
+ query = """
247
+ SELECT * FROM c
248
+ WHERE c.{partition_key_attribute}=@partition_key_value
249
+ {conditions_clause}
250
+ {visibility_condition}
251
+ {custom_sql_conditions_clause}
252
+ {order_clause}
191
253
OFFSET @offset LIMIT @max_count
192
- """ .format (partition_key_attribute = self .partition_key_attribute ,
193
- visibility_condition = self .create_sql_condition_for_visibility (visible_only ),
194
- conditions_clause = self .create_sql_where_conditions (conditions ),
195
- custom_sql_conditions_clause = self .create_custom_sql_conditions (custom_sql_conditions ),
196
- order_clause = self .create_sql_order_clause ()),
197
- parameters = params ,
198
- partition_key = partition_key_value ,
199
- max_item_count = max_count )
254
+ """ .format (
255
+ partition_key_attribute = self .partition_key_attribute ,
256
+ visibility_condition = self .create_sql_condition_for_visibility (
257
+ visible_only
258
+ ),
259
+ conditions_clause = self .create_sql_where_conditions (conditions ),
260
+ custom_sql_conditions_clause = self .create_custom_sql_conditions (
261
+ custom_sql_conditions
262
+ ),
263
+ order_clause = self .create_sql_order_clause (),
264
+ ),
265
+ parameters = params ,
266
+ partition_key = partition_key_value ,
267
+ max_item_count = max_count ,
268
+ )
200
269
201
270
function_mapper = self .get_mapper_or_dict (mapper )
202
271
return list (map (function_mapper , result ))
203
272
204
- def partial_update (self , id : str , changes : dict , event_context : EventContext ,
205
- peeker : 'function' = None , visible_only = True , mapper : Callable = None ):
206
- item_data = self .find (id , event_context , peeker = peeker , visible_only = visible_only , mapper = dict )
273
+ def partial_update (
274
+ self ,
275
+ id : str ,
276
+ changes : dict ,
277
+ event_context : EventContext ,
278
+ peeker : 'function' = None ,
279
+ visible_only = True ,
280
+ mapper : Callable = None ,
281
+ ):
282
+ item_data = self .find (
283
+ id ,
284
+ event_context ,
285
+ peeker = peeker ,
286
+ visible_only = visible_only ,
287
+ mapper = dict ,
288
+ )
207
289
item_data .update (changes )
208
290
return self .update (id , item_data , event_context , mapper = mapper )
209
291
210
- def update (self , id : str , item_data : dict , event_context : EventContext ,
211
- mapper : Callable = None ):
292
+ def update (
293
+ self ,
294
+ id : str ,
295
+ item_data : dict ,
296
+ event_context : EventContext ,
297
+ mapper : Callable = None ,
298
+ ):
212
299
self .on_update (item_data , event_context )
213
300
function_mapper = self .get_mapper_or_dict (mapper )
214
301
self .attach_context (item_data , event_context )
215
302
return function_mapper (self .container .replace_item (id , body = item_data ))
216
303
217
- def delete (self , id : str , event_context : EventContext ,
218
- peeker : 'function' = None , mapper : Callable = None ):
219
- return self .partial_update (id , {
220
- 'deleted' : generate_uuid4 ()
221
- }, event_context , peeker = peeker , visible_only = True , mapper = mapper )
304
+ def delete (
305
+ self ,
306
+ id : str ,
307
+ event_context : EventContext ,
308
+ peeker : 'function' = None ,
309
+ mapper : Callable = None ,
310
+ ):
311
+ return self .partial_update (
312
+ id ,
313
+ {'deleted' : generate_uuid4 ()},
314
+ event_context ,
315
+ peeker = peeker ,
316
+ visible_only = True ,
317
+ mapper = mapper ,
318
+ )
222
319
223
320
def delete_permanently (self , id : str , event_context : EventContext ) -> None :
224
321
partition_key_value = self .find_partition_key_value (event_context )
@@ -235,14 +332,6 @@ def get_page_size_or(self, custom_page_size: int) -> int:
235
332
# or any other repository for the settings
236
333
return custom_page_size or 100
237
334
238
- def on_create (self , new_item_data : dict , event_context : EventContext ):
239
- if new_item_data .get ('id' ) is None :
240
- new_item_data ['id' ] = generate_uuid4 ()
241
-
242
- new_item_data [self .partition_key_attribute ] = self .find_partition_key_value (event_context )
243
-
244
- self .replace_empty_value_per_none (new_item_data )
245
-
246
335
def on_update (self , update_item_data : dict , event_context : EventContext ):
247
336
pass
248
337
@@ -276,9 +365,12 @@ def delete(self, id):
276
365
event_ctx = self .create_event_context ("delete" )
277
366
self .repository .delete (id , event_ctx )
278
367
279
- def create_event_context (self , action : str = None , description : str = None ):
280
- return EventContext (self .repository .container .id , action ,
281
- description = description )
368
+ def create_event_context (
369
+ self , action : str = None , description : str = None
370
+ ):
371
+ return EventContext (
372
+ self .repository .container .id , action , description = description
373
+ )
282
374
283
375
284
376
class CustomError (HTTPException ):
@@ -324,10 +416,7 @@ def get_current_month() -> int:
324
416
return datetime .now ().month
325
417
326
418
327
- def get_date_range_of_month (
328
- year : int ,
329
- month : int
330
- ) -> Dict [str , str ]:
419
+ def get_date_range_of_month (year : int , month : int ) -> Dict [str , str ]:
331
420
first_day_of_month = 1
332
421
start_date = datetime (year = year , month = month , day = first_day_of_month )
333
422
@@ -339,10 +428,10 @@ def get_date_range_of_month(
339
428
hour = 23 ,
340
429
minute = 59 ,
341
430
second = 59 ,
342
- microsecond = 999999
431
+ microsecond = 999999 ,
343
432
)
344
433
345
434
return {
346
435
'start_date' : datetime_str (start_date ),
347
- 'end_date' : datetime_str (end_date )
436
+ 'end_date' : datetime_str (end_date ),
348
437
}
0 commit comments