Skip to content

Commit 6daf271

Browse files
authored
Merge pull request piccolo-orm#12 from piccolo-orm/interval_column_type
added Interval column type
2 parents e1c1535 + 3adb219 commit 6daf271

File tree

10 files changed

+272
-5
lines changed

10 files changed

+272
-5
lines changed

docs/src/piccolo/schema/column_types.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ Date
117117

118118
.. autoclass:: Date
119119

120+
========
121+
Interval
122+
========
123+
124+
.. autoclass:: Interval
125+
120126
====
121127
Time
122128
====

piccolo/apps/playground/commands/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Varchar,
1313
ForeignKey,
1414
Integer,
15+
Interval,
1516
Numeric,
1617
Timestamp,
1718
UUID,
@@ -40,6 +41,7 @@ class Concert(Table):
4041
band_2 = ForeignKey(Band)
4142
venue = ForeignKey(Venue)
4243
starts = Timestamp()
44+
duration = Interval()
4345

4446

4547
class Ticket(Table):
@@ -91,6 +93,7 @@ def populate():
9193
band_2=rustaceans.id,
9294
venue=venue.id,
9395
starts=datetime.datetime.now() + datetime.timedelta(days=7),
96+
duration=datetime.timedelta(hours=2),
9497
)
9598
concert.save().run_sync()
9699

piccolo/columns/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Float,
55
ForeignKey,
66
Integer,
7+
Interval,
78
Numeric,
89
PrimaryKey,
910
Real,

piccolo/columns/column_types.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import copy
3-
from datetime import datetime, date, time
3+
from datetime import datetime, date, time, timedelta
44
import decimal
55
import typing as t
66
import uuid
@@ -9,6 +9,7 @@
99
from piccolo.columns.operators.string import ConcatPostgres, ConcatSQLite
1010
from piccolo.columns.defaults.date import DateArg, DateNow, DateCustom
1111
from piccolo.columns.defaults.time import TimeArg, TimeNow, TimeCustom
12+
from piccolo.columns.defaults.interval import IntervalArg, IntervalCustom
1213
from piccolo.columns.defaults.timestamp import (
1314
TimestampArg,
1415
TimestampNow,
@@ -612,6 +613,57 @@ def __init__(self, default: TimeArg = TimeNow(), **kwargs) -> None:
612613
super().__init__(**kwargs)
613614

614615

616+
class Interval(Column): # lgtm [py/missing-equals]
617+
"""
618+
Used for storing timedeltas. Uses the ``timedelta`` type for values.
619+
620+
**Example**
621+
622+
.. code-block:: python
623+
624+
from datetime import timedelta
625+
626+
class Concert(Table):
627+
duration = Interval()
628+
629+
# Create
630+
>>> Concert(
631+
>>> duration=timedelta(hours=2)
632+
>>> ).save().run_sync()
633+
634+
# Query
635+
>>> Concert.select(Concert.duration).run_sync()
636+
{'duration': datetime.timedelta(seconds=7200)}
637+
638+
"""
639+
640+
value_type = timedelta
641+
642+
def __init__(
643+
self, default: IntervalArg = IntervalCustom(), **kwargs
644+
) -> None:
645+
self._validate_default(default, IntervalArg.__args__) # type: ignore
646+
647+
if isinstance(default, timedelta):
648+
default = IntervalCustom.from_timedelta(default)
649+
650+
self.default = default
651+
kwargs.update({"default": default})
652+
super().__init__(**kwargs)
653+
654+
@property
655+
def column_type(self):
656+
engine_type = self._meta.table._meta.db.engine_type
657+
if engine_type == "postgres":
658+
return "INTERVAL"
659+
elif engine_type == "sqlite":
660+
# We can't use 'INTERVAL' because the type affinity in SQLite would
661+
# make it an integer - but we need a numeric field.
662+
# https://sqlite.org/datatype3.html#determination_of_column_affinity
663+
return "SECONDS"
664+
raise Exception("Unrecognized engine type")
665+
666+
615667
###############################################################################
616668

617669

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
import datetime
3+
import typing as t
4+
5+
from .base import Default
6+
7+
8+
class IntervalCustom(Default): # lgtm [py/missing-equals]
9+
def __init__(
10+
self,
11+
weeks: int = 0,
12+
days: int = 0,
13+
hours: int = 0,
14+
minutes: int = 0,
15+
seconds: int = 0,
16+
milliseconds: int = 0,
17+
microseconds: int = 0,
18+
):
19+
self.weeks = weeks
20+
self.days = days
21+
self.hours = hours
22+
self.minutes = minutes
23+
self.seconds = seconds
24+
self.milliseconds = milliseconds
25+
self.microseconds = microseconds
26+
27+
@property
28+
def timedelta(self):
29+
return datetime.timedelta(
30+
weeks=self.weeks,
31+
days=self.days,
32+
hours=self.hours,
33+
minutes=self.minutes,
34+
seconds=self.seconds,
35+
milliseconds=self.milliseconds,
36+
microseconds=self.microseconds,
37+
)
38+
39+
@property
40+
def postgres(self):
41+
value = self.get_postgres_interval_string(
42+
attributes=[
43+
"weeks",
44+
"days",
45+
"hours",
46+
"minutes",
47+
"seconds",
48+
"milliseconds",
49+
"microseconds",
50+
]
51+
)
52+
return f"'{value}'"
53+
54+
@property
55+
def sqlite(self):
56+
return self.timedelta.total_seconds()
57+
58+
def python(self):
59+
return self.timedelta
60+
61+
@classmethod
62+
def from_timedelta(cls, instance: datetime.timedelta):
63+
return cls(
64+
days=instance.days,
65+
seconds=instance.seconds,
66+
microseconds=instance.microseconds,
67+
)
68+
69+
70+
###############################################################################
71+
72+
IntervalArg = t.Union[
73+
IntervalCustom, None, datetime.timedelta,
74+
]
75+
76+
77+
__all__ = [
78+
"IntervalArg",
79+
"IntervalCustom",
80+
]

piccolo/engine/sqlite.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ def convert_date_in(value):
5151
return value.isoformat()
5252

5353

54+
def convert_timedelta_in(value):
55+
"""
56+
Converts the timedelta value being passed into sqlite.
57+
"""
58+
return value.total_seconds()
59+
60+
5461
def convert_numeric_out(value: bytes) -> Decimal:
5562
"""
5663
Convert float values into Decimals.
@@ -78,21 +85,30 @@ def convert_date_out(value: bytes) -> datetime.date:
7885

7986
def convert_time_out(value: bytes) -> datetime.time:
8087
"""
81-
If the value is a uuid, convert it to a UUID instance.
88+
If the value is a time, convert it to a UUID instance.
8289
"""
8390
return datetime.time.fromisoformat(value.decode("utf8"))
8491

8592

93+
def convert_seconds_out(value: bytes) -> datetime.timedelta:
94+
"""
95+
If the value is from a seconds column, convert it to a timedelta instance.
96+
"""
97+
return datetime.timedelta(seconds=float(value.decode("utf8")))
98+
99+
86100
sqlite3.register_converter("Numeric", convert_numeric_out)
87101
sqlite3.register_converter("Integer", convert_int_out)
88102
sqlite3.register_converter("UUID", convert_uuid_out)
89103
sqlite3.register_converter("Date", convert_date_out)
90104
sqlite3.register_converter("Time", convert_time_out)
105+
sqlite3.register_converter("Seconds", convert_seconds_out)
91106

92107
sqlite3.register_adapter(Decimal, convert_numeric_in)
93108
sqlite3.register_adapter(uuid.UUID, convert_uuid_in)
94109
sqlite3.register_adapter(datetime.time, convert_time_in)
95110
sqlite3.register_adapter(datetime.date, convert_date_in)
111+
sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in)
96112

97113
###############################################################################
98114

@@ -317,6 +333,8 @@ def create_db(self, migrate=False):
317333
else:
318334
raise Exception(f"Database at {self.path} already exists")
319335
if migrate:
336+
# Commented out for now, as migrations for SQLite aren't as
337+
# well supported as Postgres.
320338
# from piccolo.commands.migration.forwards import (
321339
# ForwardsMigrationManager,
322340
# )

piccolo/query/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import datetime
23
import itertools
34
from time import time
45
import typing as t
@@ -26,7 +27,9 @@ def default(obj):
2627
"""
2728
Used for handling edge cases which orjson can't serialise out of the box.
2829
"""
29-
if isinstance(obj, UUID):
30+
# This is the asyncpg UUID, not the builtin Python UUID, which orjon can
31+
# serialise out of the box.
32+
if isinstance(obj, (UUID, datetime.timedelta)):
3033
return str(obj)
3134
raise TypeError
3235

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ asgiref==3.2.10
33
asyncpg==0.21.0
44
black==19.10b0
55
colorama==0.4.3
6-
orjson==3.3.1
6+
orjson==3.4.0
77
Jinja2==2.11.2
88
targ==0.1.*

tests/columns/test_interval.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import datetime
2+
from unittest import TestCase
3+
4+
from piccolo.table import Table
5+
from piccolo.columns.column_types import Interval
6+
from piccolo.columns.defaults.interval import IntervalCustom
7+
8+
9+
class MyTable(Table):
10+
interval = Interval()
11+
12+
13+
class MyTableDefault(Table):
14+
interval = Interval(default=IntervalCustom(days=1))
15+
16+
17+
class TestInterval(TestCase):
18+
def setUp(self):
19+
MyTable.create_table().run_sync()
20+
21+
def tearDown(self):
22+
MyTable.alter().drop_table().run_sync()
23+
24+
def test_interval(self):
25+
# Test a range of different timedeltas
26+
intervals = [
27+
datetime.timedelta(weeks=1),
28+
datetime.timedelta(days=1),
29+
datetime.timedelta(hours=1),
30+
datetime.timedelta(minutes=1),
31+
datetime.timedelta(seconds=1),
32+
datetime.timedelta(milliseconds=1),
33+
datetime.timedelta(microseconds=1),
34+
]
35+
36+
for interval in intervals:
37+
# Make sure that saving it works OK.
38+
row = MyTable(interval=interval)
39+
row.save().run_sync()
40+
41+
# Make sure that fetching it back works OK.
42+
result = (
43+
MyTable.objects()
44+
.where(MyTable.id == row.id)
45+
.first()
46+
.run_sync()
47+
)
48+
self.assertEqual(result.interval, interval)
49+
50+
def test_interval_where_clause(self):
51+
"""
52+
Make sure a range of where clauses resolve correctly.
53+
"""
54+
interval = datetime.timedelta(hours=2)
55+
row = MyTable(interval=interval)
56+
row.save().run_sync()
57+
58+
result = (
59+
MyTable.objects()
60+
.where(MyTable.interval < datetime.timedelta(hours=3))
61+
.first()
62+
.run_sync()
63+
)
64+
self.assertTrue(result is not None)
65+
66+
result = (
67+
MyTable.objects()
68+
.where(MyTable.interval < datetime.timedelta(days=1))
69+
.first()
70+
.run_sync()
71+
)
72+
self.assertTrue(result is not None)
73+
74+
result = (
75+
MyTable.objects()
76+
.where(MyTable.interval > datetime.timedelta(minutes=1))
77+
.first()
78+
.run_sync()
79+
)
80+
self.assertTrue(result is not None)
81+
82+
result = (
83+
MyTable.exists()
84+
.where(MyTable.interval == datetime.timedelta(hours=2))
85+
.run_sync()
86+
)
87+
self.assertTrue(result)
88+
89+
90+
class TestIntervalDefault(TestCase):
91+
def setUp(self):
92+
MyTableDefault.create_table().run_sync()
93+
94+
def tearDown(self):
95+
MyTableDefault.alter().drop_table().run_sync()
96+
97+
def test_interval(self):
98+
row = MyTableDefault()
99+
row.save().run_sync()
100+
101+
result = MyTableDefault.objects().first().run_sync()
102+
self.assertTrue(result.interval.days == 1)

0 commit comments

Comments
 (0)