Skip to content

Commit db09579

Browse files
authored
make get_or_create work for complex where clauses (piccolo-orm#195)
* make get_or_create work for complex where clauses * update get_or_create docs * fix LGTM warning, and add `_was_created` attribute
1 parent 7d99e83 commit db09579

File tree

4 files changed

+143
-8
lines changed

4 files changed

+143
-8
lines changed

docs/src/piccolo/query_types/objects.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,32 @@ or create a new one with the ``defaults`` arguments:
113113
Band.name == 'Pythonistas', defaults={'popularity': 100}
114114
).run_sync()
115115
116+
You can find out if an existing row was found, or if a new row was created:
117+
118+
.. code-block:: python
119+
120+
band = Band.objects.get_or_create(
121+
Band.name == 'Pythonistas'
122+
).run_sync()
123+
band._was_created # True if it was created, otherwise False if it was already in the db
124+
125+
Complex where clauses are supported, but only within reason. For example:
126+
127+
.. code-block:: python
128+
129+
# This works OK:
130+
band = Band.objects().get_or_create(
131+
(Band.name == 'Pythonistas') & (Band.popularity == 1000),
132+
).run_sync()
133+
134+
# This is problematic, as it's unclear what the name should be if we
135+
# need to create the row:
136+
band = Band.objects().get_or_create(
137+
(Band.name == 'Pythonistas') | (Band.name == 'Rustaceans'),
138+
defaults={'popularity': 100}
139+
).run_sync()
140+
141+
116142
Query clauses
117143
-------------
118144

piccolo/columns/combination.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import typing as t
44

5-
from piccolo.columns.operators.comparison import ComparisonOperator
5+
from piccolo.columns.operators.comparison import ComparisonOperator, Equal
66
from piccolo.custom_types import Combinable, Iterable
77
from piccolo.querystring import QueryString
88
from piccolo.utils.sql_values import convert_to_sql_value
@@ -44,6 +44,35 @@ def __str__(self):
4444
class And(Combination):
4545
operator = "AND"
4646

47+
def get_column_values(self) -> t.Dict[Column, t.Any]:
48+
"""
49+
This is used by `get_or_create` to know which values to assign if
50+
the row doesn't exist in the database.
51+
52+
For example, if we have:
53+
54+
(Band.name == 'Pythonistas') & (Band.popularity == 1000)
55+
56+
We will return {Band.name: 'Pythonistas', Band.popularity: 1000}.
57+
58+
If the operator is anything besides equals, we don't return it, for
59+
example:
60+
61+
(Band.name == 'Pythonistas') & (Band.popularity > 1000)
62+
63+
Returns {Band.name: 'Pythonistas'}
64+
65+
"""
66+
output = {}
67+
for combinable in (self.first, self.second):
68+
if isinstance(combinable, Where):
69+
if combinable.operator == Equal:
70+
output[combinable.column] = combinable.value
71+
elif isinstance(combinable, And):
72+
output.update(combinable.get_column_values())
73+
74+
return output
75+
4776

4877
class Or(Combination):
4978
operator = "OR"

piccolo/query/methods/objects.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44
from dataclasses import dataclass
55

6+
from piccolo.columns.combination import And, Where
67
from piccolo.custom_types import Combinable
78
from piccolo.engine.base import Batch
89
from piccolo.query.base import Query
@@ -32,14 +33,22 @@ class GetOrCreate:
3233
async def run(self):
3334
instance = await self.query.get(self.where).run()
3435
if instance:
36+
instance._was_created = False
3537
return instance
3638

3739
instance = self.query.table()
38-
setattr(
39-
instance,
40-
self.where.column._meta.name, # type: ignore
41-
self.where.value, # type: ignore
42-
)
40+
41+
# If it's a complex `where`, there can be several column values to
42+
# extract e.g. (Band.name == 'Pythonistas') & (Band.popularity == 1000)
43+
if isinstance(self.where, Where):
44+
setattr(
45+
instance,
46+
self.where.column._meta.name, # type: ignore
47+
self.where.value, # type: ignore
48+
)
49+
elif isinstance(self.where, And):
50+
for column, value in self.where.get_column_values().items():
51+
setattr(instance, column._meta.name, value)
4352

4453
for column, value in self.defaults.items():
4554
if isinstance(column, str):
@@ -48,6 +57,8 @@ async def run(self):
4857

4958
await instance.save().run()
5059

60+
instance._was_created = True
61+
5162
return instance
5263

5364
def __await__(self):

tests/table/test_objects.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def test_get(self):
6868
self.assertTrue(band.name == "Pythonistas")
6969

7070
def test_get_or_create(self):
71+
"""
72+
Make sure `get_or_create` works for simple where clauses.
73+
"""
74+
# When the row doesn't exist in the db:
7175
Band.objects().get_or_create(
7276
Band.name == "Pink Floyd", defaults={"popularity": 100}
7377
).run_sync()
@@ -76,10 +80,11 @@ def test_get_or_create(self):
7680
Band.objects().where(Band.name == "Pink Floyd").first().run_sync()
7781
)
7882

79-
self.assertTrue(isinstance(instance, Band))
83+
self.assertIsInstance(instance, Band)
8084
self.assertTrue(instance.name == "Pink Floyd")
8185
self.assertTrue(instance.popularity == 100)
8286

87+
# When the row already exists in the db:
8388
Band.objects().get_or_create(
8489
Band.name == "Pink Floyd", defaults={Band.popularity: 100}
8590
).run_sync()
@@ -88,6 +93,70 @@ def test_get_or_create(self):
8893
Band.objects().where(Band.name == "Pink Floyd").first().run_sync()
8994
)
9095

91-
self.assertTrue(isinstance(instance, Band))
96+
self.assertIsInstance(instance, Band)
9297
self.assertTrue(instance.name == "Pink Floyd")
9398
self.assertTrue(instance.popularity == 100)
99+
100+
def test_get_or_create_complex(self):
101+
"""
102+
Make sure `get_or_create` works with complex where clauses.
103+
"""
104+
self.insert_rows()
105+
106+
# When the row already exists in the db:
107+
instance = (
108+
Band.objects()
109+
.get_or_create(
110+
(Band.name == "Pythonistas") & (Band.popularity == 1000)
111+
)
112+
.run_sync()
113+
)
114+
self.assertIsInstance(instance, Band)
115+
self.assertEqual(instance._was_created, False)
116+
117+
# When the row doesn't exist in the db:
118+
instance = (
119+
Band.objects()
120+
.get_or_create(
121+
(Band.name == "Pythonistas2") & (Band.popularity == 2000)
122+
)
123+
.run_sync()
124+
)
125+
self.assertIsInstance(instance, Band)
126+
self.assertEqual(instance._was_created, True)
127+
128+
def test_get_or_create_very_complex(self):
129+
"""
130+
Make sure `get_or_create` works with very complex where clauses.
131+
"""
132+
self.insert_rows()
133+
134+
# When the row already exists in the db:
135+
instance = (
136+
Band.objects()
137+
.get_or_create(
138+
(Band.name == "Pythonistas")
139+
& (Band.popularity > 0)
140+
& (Band.popularity < 5000)
141+
)
142+
.run_sync()
143+
)
144+
self.assertIsInstance(instance, Band)
145+
self.assertEqual(instance._was_created, False)
146+
147+
# When the row doesn't exist in the db:
148+
instance = (
149+
Band.objects()
150+
.get_or_create(
151+
(Band.name == "Pythonistas2")
152+
& (Band.popularity > 10)
153+
& (Band.popularity < 5000)
154+
)
155+
.run_sync()
156+
)
157+
self.assertIsInstance(instance, Band)
158+
self.assertEqual(instance._was_created, True)
159+
160+
# The values in the > and < should be ignored, and the default should
161+
# be used for the column.
162+
self.assertEqual(instance.popularity, 0)

0 commit comments

Comments
 (0)