Skip to content

Commit b3062bf

Browse files
authored
Merge pull request piccolo-orm#92 from piccolo-orm/array_filtering
added support for `any` and `all` for array filtering
2 parents 90ac5f6 + a558d2f commit b3062bf

File tree

4 files changed

+111
-0
lines changed

4 files changed

+111
-0
lines changed

docs/src/piccolo/schema/column_types.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,15 @@ Accessing individual elements
221221
=============================
222222

223223
.. automethod:: Array.__getitem__
224+
225+
===
226+
any
227+
===
228+
229+
.. automethod:: Array.any
230+
231+
===
232+
all
233+
===
234+
235+
.. automethod:: Array.all

piccolo/columns/column_types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
OnDelete,
1313
OnUpdate,
1414
)
15+
from piccolo.columns.combination import Where
1516
from piccolo.columns.defaults.date import DateArg, DateCustom, DateNow
1617
from piccolo.columns.defaults.interval import IntervalArg, IntervalCustom
1718
from piccolo.columns.defaults.time import TimeArg, TimeCustom, TimeNow
@@ -26,6 +27,7 @@
2627
TimestamptzNow,
2728
)
2829
from piccolo.columns.defaults.uuid import UUID4, UUIDArg
30+
from piccolo.columns.operators.comparison import ArrayAll, ArrayAny
2931
from piccolo.columns.operators.string import ConcatPostgres, ConcatSQLite
3032
from piccolo.columns.reference import LazyTableReference
3133
from piccolo.querystring import QueryString, Unquoted
@@ -1416,6 +1418,12 @@ def __getitem__(self, value: int) -> Array:
14161418
14171419
14181420
"""
1421+
engine_type = self._meta.table._meta.db.engine_type
1422+
if engine_type != "postgres":
1423+
raise ValueError(
1424+
"Only Postgres supports array indexing currently."
1425+
)
1426+
14191427
if isinstance(value, int):
14201428
if value < 0:
14211429
raise ValueError("Only positive integers are allowed.")
@@ -1436,3 +1444,39 @@ def get_select_string(self, engine_type: str, just_alias=False) -> str:
14361444
return f"{select_string}[{self.index}]"
14371445
else:
14381446
return select_string
1447+
1448+
def any(self, value: t.Any) -> Where:
1449+
"""
1450+
Check if any of the items in the array match the given value.
1451+
1452+
.. code-block:: python
1453+
1454+
>>> Ticket.select().where(Ticket.seat_numbers.any(510)).run_sync()
1455+
1456+
"""
1457+
engine_type = self._meta.table._meta.db.engine_type
1458+
1459+
if engine_type == "postgres":
1460+
return Where(column=self, value=value, operator=ArrayAny)
1461+
elif engine_type == "sqlite":
1462+
return self.like(f"%{value}%")
1463+
else:
1464+
raise ValueError("Unrecognised engine type")
1465+
1466+
def all(self, value: t.Any) -> Where:
1467+
"""
1468+
Check if all of the items in the array match the given value.
1469+
1470+
.. code-block:: python
1471+
1472+
>>> Ticket.select().where(Ticket.seat_numbers.all(510)).run_sync()
1473+
1474+
"""
1475+
engine_type = self._meta.table._meta.db.engine_type
1476+
1477+
if engine_type == "postgres":
1478+
return Where(column=self, value=value, operator=ArrayAll)
1479+
elif engine_type == "sqlite":
1480+
raise ValueError("Unsupported by SQLite")
1481+
else:
1482+
raise ValueError("Unrecognised engine type")

piccolo/columns/operators/comparison.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,11 @@ class LessThan(ComparisonOperator):
5656

5757
class LessEqualThan(ComparisonOperator):
5858
template = "{name} <= {value}"
59+
60+
61+
class ArrayAny(ComparisonOperator):
62+
template = "{value} = ANY ({name})"
63+
64+
65+
class ArrayAll(ComparisonOperator):
66+
template = "{value} = ALL ({name})"

tests/columns/test_array.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,50 @@ def test_index(self):
3939
self.assertEqual(
4040
MyTable.select(MyTable.value[0]).first().run_sync(), {"value": 1}
4141
)
42+
43+
@postgres_only
44+
def test_all(self):
45+
"""
46+
Make sure rows can be retrieved where all items in an array match a
47+
given value.
48+
"""
49+
MyTable(value=[1, 1, 1]).save().run_sync()
50+
51+
self.assertEqual(
52+
MyTable.select(MyTable.value)
53+
.where(MyTable.value.all(1))
54+
.first()
55+
.run_sync(),
56+
{"value": [1, 1, 1]},
57+
)
58+
59+
self.assertEqual(
60+
MyTable.select(MyTable.value)
61+
.where(MyTable.value.all(0))
62+
.first()
63+
.run_sync(),
64+
None,
65+
)
66+
67+
def test_any(self):
68+
"""
69+
Make sure rows can be retrieved where any items in an array match a
70+
given value.
71+
"""
72+
MyTable(value=[1, 2, 3]).save().run_sync()
73+
74+
self.assertEqual(
75+
MyTable.select(MyTable.value)
76+
.where(MyTable.value.any(1))
77+
.first()
78+
.run_sync(),
79+
{"value": [1, 2, 3]},
80+
)
81+
82+
self.assertEqual(
83+
MyTable.select(MyTable.value)
84+
.where(MyTable.value.any(0))
85+
.first()
86+
.run_sync(),
87+
None,
88+
)

0 commit comments

Comments
 (0)