Skip to content

Commit 9672e17

Browse files
authored
Merge pull request piccolo-orm#119 from piccolo-orm/improved_where_clause
the `where` clause can now accept `Table` instances
2 parents a20a951 + 20ff274 commit 9672e17

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

piccolo/columns/combination.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from piccolo.custom_types import Combinable, Iterable
77
from piccolo.querystring import QueryString
88

9-
if t.TYPE_CHECKING: # pragma: no cover
10-
from piccolo.columns.base import Column # noqa
9+
if t.TYPE_CHECKING:
10+
from piccolo.columns.base import Column
1111

1212

1313
class CombinableMixin(object):
@@ -92,10 +92,43 @@ def __init__(
9292
omitted, vs None, which is a valid value for a where clause.
9393
"""
9494
self.column = column
95-
self.value = value
96-
self.values = values
95+
self.value = self.clean_value(value)
96+
97+
if values == UNDEFINED:
98+
self.values = values
99+
else:
100+
self.values = [self.clean_value(i) for i in values] # type: ignore
101+
97102
self.operator = operator
98103

104+
def clean_value(self, value: t.Any) -> t.Any:
105+
"""
106+
If a where clause contains a Table instance, we should convert that
107+
to a column reference. For example:
108+
109+
.. code-block:: python
110+
111+
manager = Manager.objects.where(
112+
Manager.name == 'Guido'
113+
).first().run_sync()
114+
115+
# The where clause should be:
116+
Band.select().where(Band.manager.id == guido.id).run_sync()
117+
# Or
118+
Band.select().where(Band.manager == guido.id).run_sync()
119+
120+
# If the object is passed in, i.e. `guido` instead of `guido.id`,
121+
# it should still work.
122+
Band.select().where(Band.manager == guido).run_sync()
123+
124+
"""
125+
from piccolo.table import Table
126+
127+
if isinstance(value, Table):
128+
return value.id
129+
else:
130+
return value
131+
99132
@property
100133
def values_querystring(self) -> QueryString:
101134
if isinstance(self.values, Undefined):

tests/table/test_select.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from piccolo.query.methods.select import Count
66

77
from ..base import DBTestCase, postgres_only, sqlite_only
8-
from ..example_app.tables import Band, Concert
8+
from ..example_app.tables import Band, Concert, Manager
99

1010

1111
class TestSelect(DBTestCase):
@@ -28,6 +28,30 @@ def test_query_some_columns(self):
2828

2929
self.assertDictEqual(response[0], {"name": "Pythonistas"})
3030

31+
def test_where_equals(self):
32+
self.insert_row()
33+
34+
manager = Manager.objects().first().run_sync()
35+
36+
# This is the recommended way of running these types of queries:
37+
response = (
38+
Band.select(Band.name)
39+
.where(Band.manager.id == manager.id)
40+
.run_sync()
41+
)
42+
self.assertEqual(response, [{"name": "Pythonistas"}])
43+
44+
# Other cases which should work:
45+
response = (
46+
Band.select(Band.name).where(Band.manager == manager).run_sync()
47+
)
48+
self.assertEqual(response, [{"name": "Pythonistas"}])
49+
50+
response = (
51+
Band.select(Band.name).where(Band.manager.id == manager).run_sync()
52+
)
53+
self.assertEqual(response, [{"name": "Pythonistas"}])
54+
3155
def test_where_like(self):
3256
self.insert_rows()
3357

0 commit comments

Comments
 (0)