Skip to content

Commit 3adb219

Browse files
committed
changed Interval column implementation for SQLite, and added more tests
Rather than storing the timedelta as a string in SQLite, storing the number of seconds in a Numeric field instead - otherwise WHERE clauses don't work.
1 parent 1016665 commit 3adb219

File tree

4 files changed

+79
-29
lines changed

4 files changed

+79
-29
lines changed

piccolo/columns/column_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,18 @@ def __init__(
651651
kwargs.update({"default": default})
652652
super().__init__(**kwargs)
653653

654+
@property
655+
def column_type(self):
656+
engine_type = self._meta.table._meta.db.engine_type
657+
if engine_type == "postgres":
658+
return "INTERVAL"
659+
elif engine_type == "sqlite":
660+
# We can't use 'INTERVAL' because the type affinity in SQLite would
661+
# make it an integer - but we need a numeric field.
662+
# https://sqlite.org/datatype3.html#determination_of_column_affinity
663+
return "SECONDS"
664+
raise Exception("Unrecognized engine type")
665+
654666

655667
###############################################################################
656668

piccolo/columns/defaults/interval.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,7 @@ def postgres(self):
5353

5454
@property
5555
def sqlite(self):
56-
value = self.get_sqlite_interval_string(
57-
attributes=[
58-
"weeks",
59-
"days",
60-
"hours",
61-
"minutes",
62-
"seconds",
63-
"milliseconds",
64-
"microseconds",
65-
]
66-
).replace("'", "")
67-
return f"'{value}'"
56+
return self.timedelta.total_seconds()
6857

6958
def python(self):
7059
return self.timedelta

piccolo/engine/sqlite.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import aiosqlite
1212
from aiosqlite import Cursor, Connection
1313

14-
from piccolo.columns.defaults.interval import IntervalCustom
1514
from piccolo.engine.base import Batch, Engine
1615
from piccolo.engine.exceptions import TransactionError
1716
from piccolo.query.base import Query
@@ -56,9 +55,7 @@ def convert_timedelta_in(value):
5655
"""
5756
Converts the timedelta value being passed into sqlite.
5857
"""
59-
return IntervalCustom.from_timedelta(instance=value).sqlite.replace(
60-
"'", ""
61-
)
58+
return value.total_seconds()
6259

6360

6461
def convert_numeric_out(value: bytes) -> Decimal:
@@ -93,26 +90,19 @@ def convert_time_out(value: bytes) -> datetime.time:
9390
return datetime.time.fromisoformat(value.decode("utf8"))
9491

9592

96-
def convert_interval_out(value: bytes) -> datetime.timedelta:
93+
def convert_seconds_out(value: bytes) -> datetime.timedelta:
9794
"""
98-
If the value is an interval, convert it to a timedelta instance.
95+
If the value is from a seconds column, convert it to a timedelta instance.
9996
"""
100-
unit_value_strings = [i.strip() for i in value.decode("utf8").split(",")]
101-
kwargs = {}
102-
103-
for unit_value_string in unit_value_strings:
104-
value, key = unit_value_string.split(" ")
105-
kwargs[key] = int(value)
106-
107-
return datetime.timedelta(**kwargs)
97+
return datetime.timedelta(seconds=float(value.decode("utf8")))
10898

10999

110100
sqlite3.register_converter("Numeric", convert_numeric_out)
111101
sqlite3.register_converter("Integer", convert_int_out)
112102
sqlite3.register_converter("UUID", convert_uuid_out)
113103
sqlite3.register_converter("Date", convert_date_out)
114104
sqlite3.register_converter("Time", convert_time_out)
115-
sqlite3.register_converter("Interval", convert_interval_out)
105+
sqlite3.register_converter("Seconds", convert_seconds_out)
116106

117107
sqlite3.register_adapter(Decimal, convert_numeric_in)
118108
sqlite3.register_adapter(uuid.UUID, convert_uuid_in)
@@ -343,6 +333,8 @@ def create_db(self, migrate=False):
343333
else:
344334
raise Exception(f"Database at {self.path} already exists")
345335
if migrate:
336+
# Commented out for now, as migrations for SQLite aren't as
337+
# well supported as Postgres.
346338
# from piccolo.commands.migration.forwards import (
347339
# ForwardsMigrationManager,
348340
# )

tests/columns/test_interval.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,69 @@ def tearDown(self):
2222
MyTable.alter().drop_table().run_sync()
2323

2424
def test_interval(self):
25+
# Test a range of different timedeltas
26+
intervals = [
27+
datetime.timedelta(weeks=1),
28+
datetime.timedelta(days=1),
29+
datetime.timedelta(hours=1),
30+
datetime.timedelta(minutes=1),
31+
datetime.timedelta(seconds=1),
32+
datetime.timedelta(milliseconds=1),
33+
datetime.timedelta(microseconds=1),
34+
]
35+
36+
for interval in intervals:
37+
# Make sure that saving it works OK.
38+
row = MyTable(interval=interval)
39+
row.save().run_sync()
40+
41+
# Make sure that fetching it back works OK.
42+
result = (
43+
MyTable.objects()
44+
.where(MyTable.id == row.id)
45+
.first()
46+
.run_sync()
47+
)
48+
self.assertEqual(result.interval, interval)
49+
50+
def test_interval_where_clause(self):
51+
"""
52+
Make sure a range of where clauses resolve correctly.
53+
"""
2554
interval = datetime.timedelta(hours=2)
2655
row = MyTable(interval=interval)
2756
row.save().run_sync()
2857

29-
result = MyTable.objects().first().run_sync()
30-
self.assertEqual(result.interval, interval)
58+
result = (
59+
MyTable.objects()
60+
.where(MyTable.interval < datetime.timedelta(hours=3))
61+
.first()
62+
.run_sync()
63+
)
64+
self.assertTrue(result is not None)
65+
66+
result = (
67+
MyTable.objects()
68+
.where(MyTable.interval < datetime.timedelta(days=1))
69+
.first()
70+
.run_sync()
71+
)
72+
self.assertTrue(result is not None)
73+
74+
result = (
75+
MyTable.objects()
76+
.where(MyTable.interval > datetime.timedelta(minutes=1))
77+
.first()
78+
.run_sync()
79+
)
80+
self.assertTrue(result is not None)
81+
82+
result = (
83+
MyTable.exists()
84+
.where(MyTable.interval == datetime.timedelta(hours=2))
85+
.run_sync()
86+
)
87+
self.assertTrue(result)
3188

3289

3390
class TestIntervalDefault(TestCase):

0 commit comments

Comments
 (0)