Skip to content

Commit 7600fd5

Browse files
committed
added array support to sqlite
1 parent 27a942b commit 7600fd5

File tree

5 files changed

+51
-7
lines changed

5 files changed

+51
-7
lines changed

piccolo/columns/column_types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1366,4 +1366,9 @@ def __init__(
13661366

13671367
@property
13681368
def column_type(self):
1369-
return f"{self.base_column.column_type}[]"
1369+
engine_type = self._meta.table._meta.db.engine_type
1370+
if engine_type == "postgres":
1371+
return f"{self.base_column.column_type}[]"
1372+
elif engine_type == "sqlite":
1373+
return "ARRAY"
1374+
raise Exception("Unrecognized engine type")

piccolo/engine/sqlite.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from piccolo.engine.exceptions import TransactionError
1515
from piccolo.query.base import Query
1616
from piccolo.querystring import QueryString
17+
18+
from piccolo.utils.encoding import dump_json, load_json
1719
from piccolo.utils.sync import run_sync
1820

1921
###############################################################################
@@ -39,14 +41,14 @@ def convert_uuid_in(value) -> str:
3941
return str(value)
4042

4143

42-
def convert_time_in(value) -> str:
44+
def convert_time_in(value: datetime.time) -> str:
4345
"""
4446
Converts the time value being passed into sqlite.
4547
"""
4648
return value.isoformat()
4749

4850

49-
def convert_date_in(value):
51+
def convert_date_in(value: datetime.date):
5052
"""
5153
Converts the date value being passed into sqlite.
5254
"""
@@ -64,13 +66,24 @@ def convert_datetime_in(value: datetime.datetime) -> str:
6466
return str(value)
6567

6668

67-
def convert_timedelta_in(value):
69+
def convert_timedelta_in(value: datetime.timedelta):
6870
"""
6971
Converts the timedelta value being passed into sqlite.
7072
"""
7173
return value.total_seconds()
7274

7375

76+
def convert_array_in(value: list):
77+
"""
78+
Converts a list value into a string.
79+
"""
80+
if len(value) > 0:
81+
if type(value[0]) not in [str, int, float]:
82+
raise ValueError("Can only serialise str, int and float.")
83+
84+
return dump_json(value)
85+
86+
7487
# Out
7588

7689

@@ -129,6 +142,14 @@ def convert_timestamptz_out(value: bytes) -> datetime.datetime:
129142
return datetime.datetime.fromisoformat(value.decode("utf8"))
130143

131144

145+
def convert_array_out(value: bytes) -> t.List:
146+
"""
147+
If the value if from an array column, deserialise the string back into a
148+
list.
149+
"""
150+
return load_json(value.decode("utf8"))
151+
152+
132153
sqlite3.register_converter("Numeric", convert_numeric_out)
133154
sqlite3.register_converter("Integer", convert_int_out)
134155
sqlite3.register_converter("UUID", convert_uuid_out)
@@ -137,13 +158,15 @@ def convert_timestamptz_out(value: bytes) -> datetime.datetime:
137158
sqlite3.register_converter("Seconds", convert_seconds_out)
138159
sqlite3.register_converter("Boolean", convert_boolean_out)
139160
sqlite3.register_converter("Timestamptz", convert_timestamptz_out)
161+
sqlite3.register_converter("Array", convert_array_out)
140162

141163
sqlite3.register_adapter(Decimal, convert_numeric_in)
142164
sqlite3.register_adapter(uuid.UUID, convert_uuid_in)
143165
sqlite3.register_adapter(datetime.time, convert_time_in)
144166
sqlite3.register_adapter(datetime.date, convert_date_in)
145167
sqlite3.register_adapter(datetime.datetime, convert_datetime_in)
146168
sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in)
169+
sqlite3.register_adapter(list, convert_array_in)
147170

148171
###############################################################################
149172

piccolo/utils/encoding.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,10 @@ def dump_json(data: t.Any) -> str:
1616
return orjson.dumps(data, default=str).decode("utf8")
1717
else:
1818
return json.dumps(data, default=str)
19+
20+
21+
def load_json(data: str) -> t.Any:
22+
if ORJSON:
23+
return orjson.loads(data)
24+
else:
25+
return json.loads(data)

tests/columns/test_array.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33
from piccolo.table import Table
44
from piccolo.columns.column_types import Array, Integer
55

6-
from ..base import postgres_only
7-
86

97
class MyTable(Table):
108
value = Array(base_column=Integer())
119

1210

13-
@postgres_only
1411
class TestArrayPostgres(TestCase):
1512
"""
1613
Make sure an Array column can be created.

tests/utils/test_encoding.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from piccolo.utils.encoding import dump_json, load_json
4+
5+
6+
class TestEncodingDecoding(TestCase):
7+
def test_dump_load(self):
8+
"""
9+
Test dumping then loading an object.
10+
"""
11+
payload = {"a": [1, 2, 3]}
12+
self.assertEqual(load_json(dump_json(payload)), payload)

0 commit comments

Comments
 (0)