Skip to content

Commit de78c07

Browse files
authored
Merge pull request piccolo-orm#91 from piccolo-orm/array_column_type
Array column type
2 parents c755075 + c93a521 commit de78c07

File tree

7 files changed

+205
-7
lines changed

7 files changed

+205
-7
lines changed

docs/src/piccolo/schema/column_types.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,20 @@ a subset of the JSON data, and for filtering in a where clause.
204204
>>> Booking.data.arrow('name') == '"Alison"'
205205
>>> ).run_sync()
206206
[{'id': 1}]
207+
208+
-------------------------------------------------------------------------------
209+
210+
*****
211+
Array
212+
*****
213+
214+
Arrays of data can be stored, which can be useful when you want store lots of
215+
values without using foreign keys.
216+
217+
.. autoclass:: Array
218+
219+
=============================
220+
Accessing individual elements
221+
=============================
222+
223+
.. automethod:: Array.__getitem__

piccolo/columns/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,11 @@ def get_sql_value(self, value: t.Any) -> t.Any:
475475
output = f"'{value.isoformat().replace('T', '')}'"
476476
elif isinstance(value, bytes):
477477
output = f"'{value.hex()}'"
478+
elif isinstance(value, list):
479+
# Convert to the array syntax.
480+
output = (
481+
"'{" + ", ".join([self.get_sql_value(i) for i in value]) + "}'"
482+
)
478483
else:
479484
output = value
480485

piccolo/columns/column_types.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_querystring(
5252
value: t.Union[str, Varchar, Text],
5353
engine_type: str,
5454
reverse=False,
55-
):
55+
) -> QueryString:
5656
Concat = ConcatPostgres if engine_type == "postgres" else ConcatSQLite
5757

5858
if isinstance(value, (Varchar, Text)):
@@ -106,7 +106,7 @@ def get_querystring(
106106
operator: str,
107107
value: t.Union[int, float, Integer],
108108
reverse=False,
109-
):
109+
) -> QueryString:
110110
if isinstance(value, Integer):
111111
column: Integer = value
112112
if len(column._meta.call_chain) > 0:
@@ -1268,8 +1268,9 @@ def arrow(self, key: str) -> JSONB:
12681268
Allows part of the JSON structure to be returned - for example,
12691269
for {"a": 1}, and a key value of "a", then 1 will be returned.
12701270
"""
1271-
self.json_operator = f"-> '{key}'"
1272-
return self
1271+
instance = t.cast(JSONB, self.copy())
1272+
instance.json_operator = f"-> '{key}'"
1273+
return instance
12731274

12741275
def get_select_string(self, engine_type: str, just_alias=False) -> str:
12751276
select_string = self._meta.get_full_name(just_alias=just_alias)
@@ -1343,3 +1344,95 @@ class Blob(Bytea):
13431344
"""
13441345

13451346
pass
1347+
1348+
1349+
###############################################################################
1350+
1351+
1352+
class Array(Column):
1353+
"""
1354+
Used for storing lists of data.
1355+
1356+
**Example**
1357+
1358+
.. code-block:: python
1359+
1360+
class Ticket(Table):
1361+
seat_numbers = Array(base_column=Integer())
1362+
1363+
# Create
1364+
>>> Ticket(seat_numbers=[34, 35, 36]).save().run_sync()
1365+
1366+
# Query
1367+
>>> Ticket.select(Ticket.seat_numbers).run_sync()
1368+
{'seat_numbers': [34, 35, 36]}
1369+
1370+
"""
1371+
1372+
value_type = list
1373+
1374+
def __init__(
1375+
self,
1376+
base_column: Column,
1377+
default: t.Union[t.List, t.Callable[[], t.List], None] = list,
1378+
**kwargs,
1379+
) -> None:
1380+
if isinstance(base_column, ForeignKey):
1381+
raise ValueError("Arrays of ForeignKeys aren't allowed.")
1382+
1383+
self._validate_default(default, (list, None))
1384+
1385+
self.base_column = base_column
1386+
self.default = default
1387+
self.index: t.Optional[int] = None
1388+
kwargs.update({"base_column": base_column, "default": default})
1389+
super().__init__(**kwargs)
1390+
1391+
@property
1392+
def column_type(self):
1393+
engine_type = self._meta.table._meta.db.engine_type
1394+
if engine_type == "postgres":
1395+
return f"{self.base_column.column_type}[]"
1396+
elif engine_type == "sqlite":
1397+
return "ARRAY"
1398+
raise Exception("Unrecognized engine type")
1399+
1400+
def __getitem__(self, value: int) -> Array:
1401+
"""
1402+
Allows queries which retrieve an item from the array. The index starts
1403+
with 0 for the first value. If you were to write the SQL by hand, the
1404+
first index would be 1 instead:
1405+
1406+
https://www.postgresql.org/docs/current/arrays.html
1407+
1408+
However, we keep the first index as 0 to fit better with Python.
1409+
1410+
For example:
1411+
1412+
.. code-block:: python
1413+
1414+
>>> Ticket.select(Ticket.seat_numbers[0]).first().run_sync
1415+
{'seat_numbers': 325}
1416+
1417+
1418+
"""
1419+
if isinstance(value, int):
1420+
if value < 0:
1421+
raise ValueError("Only positive integers are allowed.")
1422+
1423+
instance = t.cast(Array, self.copy())
1424+
1425+
# We deliberately add 1, as Postgres treats the first array element
1426+
# as index 1.
1427+
instance.index = value + 1
1428+
return instance
1429+
else:
1430+
raise ValueError("Only integers can be used for indexing.")
1431+
1432+
def get_select_string(self, engine_type: str, just_alias=False) -> str:
1433+
select_string = self._meta.get_full_name(just_alias=just_alias)
1434+
1435+
if isinstance(self.index, int):
1436+
return f"{select_string}[{self.index}]"
1437+
else:
1438+
return select_string

piccolo/engine/sqlite.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from piccolo.engine.exceptions import TransactionError
1515
from piccolo.query.base import Query
1616
from piccolo.querystring import QueryString
17+
18+
from piccolo.utils.encoding import dump_json, load_json
1719
from piccolo.utils.sync import run_sync
1820

1921
###############################################################################
@@ -39,14 +41,14 @@ def convert_uuid_in(value) -> str:
3941
return str(value)
4042

4143

42-
def convert_time_in(value) -> str:
44+
def convert_time_in(value: datetime.time) -> str:
4345
"""
4446
Converts the time value being passed into sqlite.
4547
"""
4648
return value.isoformat()
4749

4850

49-
def convert_date_in(value):
51+
def convert_date_in(value: datetime.date):
5052
"""
5153
Converts the date value being passed into sqlite.
5254
"""
@@ -64,13 +66,24 @@ def convert_datetime_in(value: datetime.datetime) -> str:
6466
return str(value)
6567

6668

67-
def convert_timedelta_in(value):
69+
def convert_timedelta_in(value: datetime.timedelta):
6870
"""
6971
Converts the timedelta value being passed into sqlite.
7072
"""
7173
return value.total_seconds()
7274

7375

76+
def convert_array_in(value: list):
77+
"""
78+
Converts a list value into a string.
79+
"""
80+
if len(value) > 0:
81+
if type(value[0]) not in [str, int, float]:
82+
raise ValueError("Can only serialise str, int and float.")
83+
84+
return dump_json(value)
85+
86+
7487
# Out
7588

7689

@@ -129,6 +142,14 @@ def convert_timestamptz_out(value: bytes) -> datetime.datetime:
129142
return datetime.datetime.fromisoformat(value.decode("utf8"))
130143

131144

145+
def convert_array_out(value: bytes) -> t.List:
146+
"""
147+
If the value if from an array column, deserialise the string back into a
148+
list.
149+
"""
150+
return load_json(value.decode("utf8"))
151+
152+
132153
sqlite3.register_converter("Numeric", convert_numeric_out)
133154
sqlite3.register_converter("Integer", convert_int_out)
134155
sqlite3.register_converter("UUID", convert_uuid_out)
@@ -137,13 +158,15 @@ def convert_timestamptz_out(value: bytes) -> datetime.datetime:
137158
sqlite3.register_converter("Seconds", convert_seconds_out)
138159
sqlite3.register_converter("Boolean", convert_boolean_out)
139160
sqlite3.register_converter("Timestamptz", convert_timestamptz_out)
161+
sqlite3.register_converter("Array", convert_array_out)
140162

141163
sqlite3.register_adapter(Decimal, convert_numeric_in)
142164
sqlite3.register_adapter(uuid.UUID, convert_uuid_in)
143165
sqlite3.register_adapter(datetime.time, convert_time_in)
144166
sqlite3.register_adapter(datetime.date, convert_date_in)
145167
sqlite3.register_adapter(datetime.datetime, convert_datetime_in)
146168
sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in)
169+
sqlite3.register_adapter(list, convert_array_in)
147170

148171
###############################################################################
149172

piccolo/utils/encoding.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,10 @@ def dump_json(data: t.Any) -> str:
1616
return orjson.dumps(data, default=str).decode("utf8")
1717
else:
1818
return json.dumps(data, default=str)
19+
20+
21+
def load_json(data: str) -> t.Any:
22+
if ORJSON:
23+
return orjson.loads(data)
24+
else:
25+
return json.loads(data)

tests/columns/test_array.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from unittest import TestCase
2+
3+
from piccolo.table import Table
4+
from piccolo.columns.column_types import Array, Integer
5+
from tests.base import postgres_only
6+
7+
8+
class MyTable(Table):
9+
value = Array(base_column=Integer())
10+
11+
12+
class TestArrayPostgres(TestCase):
13+
"""
14+
Make sure an Array column can be created.
15+
"""
16+
17+
def setUp(self):
18+
MyTable.create_table().run_sync()
19+
20+
def tearDown(self):
21+
MyTable.alter().drop_table().run_sync()
22+
23+
def test_storage(self):
24+
"""
25+
Make sure data can be stored and retrieved.
26+
"""
27+
MyTable(value=[1, 2, 3]).save().run_sync()
28+
29+
row = MyTable.objects().first().run_sync()
30+
self.assertEqual(row.value, [1, 2, 3])
31+
32+
@postgres_only
33+
def test_index(self):
34+
"""
35+
Indexes should allow individual array elements to be queried.
36+
"""
37+
MyTable(value=[1, 2, 3]).save().run_sync()
38+
39+
self.assertEqual(
40+
MyTable.select(MyTable.value[0]).first().run_sync(), {"value": 1}
41+
)

tests/utils/test_encoding.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from piccolo.utils.encoding import dump_json, load_json
4+
5+
6+
class TestEncodingDecoding(TestCase):
7+
def test_dump_load(self):
8+
"""
9+
Test dumping then loading an object.
10+
"""
11+
payload = {"a": [1, 2, 3]}
12+
self.assertEqual(load_json(dump_json(payload)), payload)

0 commit comments

Comments
 (0)