Skip to content

Commit ca08bed

Browse files
committed
can use operators on update queries
1 parent 4528c22 commit ca08bed

File tree

3 files changed

+165
-2
lines changed

3 files changed

+165
-2
lines changed

docs/src/piccolo/query_types/update.rst

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,36 @@ This is used to update any rows in the table which match the criteria.
1414
>>> ).run_sync()
1515
[]
1616
17+
You can also add / subtract / multiply / divide values:
18+
19+
.. code-block:: python
20+
21+
# Add 100 to the popularity of each band:
22+
Band.update({
23+
Band.popularity: Band.popularity + 100
24+
}).run_sync()
25+
26+
# Decrease the popularity of each band by 100.
27+
Band.update({
28+
Band.popularity: Band.popularity - 100
29+
}).run_sync()
30+
31+
# Multiply the popularity of each band by 10.
32+
Band.update({
33+
Band.popularity: Band.popularity * 10
34+
}).run_sync()
35+
36+
# Divide the popularity of each band by 10.
37+
Band.update({
38+
Band.popularity: Band.popularity / 10
39+
}).run_sync()
40+
41+
# You can also use the operators in reverse:
42+
Band.update({
43+
Band.popularity: 2000 - Band.popularity
44+
}).run_sync()
45+
46+
1747
Query clauses
1848
-------------
1949

piccolo/columns/column_types.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uuid
66

77
from piccolo.columns.base import Column, OnDelete, OnUpdate, ForeignKeyMeta
8-
from piccolo.querystring import Unquoted
8+
from piccolo.querystring import Unquoted, QueryString
99

1010
if t.TYPE_CHECKING:
1111
from piccolo.table import Table # noqa
@@ -82,6 +82,49 @@ def __init__(self, default: int = None, **kwargs) -> None:
8282
kwargs.update({"default": default})
8383
super().__init__(**kwargs)
8484

85+
def _get_querystring(self, operator: str, value: int, reverse=False):
86+
"""
87+
Used in update queries - for example:
88+
89+
await Band.update({Band.popularity: Band.popularity + 100}).run()
90+
"""
91+
if not isinstance(value, (int, float)):
92+
raise ValueError("Only integers and floats can be added.")
93+
if reverse:
94+
return QueryString(f"{{}} {operator} {self._meta.name} ", value)
95+
else:
96+
return QueryString(f"{self._meta.name} {operator} {{}}", value)
97+
98+
def __add__(self, value: int):
99+
return self._get_querystring("+", value)
100+
101+
def __radd__(self, value: int):
102+
return self._get_querystring("+", value, reverse=True)
103+
104+
def __sub__(self, value: int):
105+
return self._get_querystring("-", value)
106+
107+
def __rsub__(self, value: int):
108+
return self._get_querystring("-", value, reverse=True)
109+
110+
def __mul__(self, value: int):
111+
return self._get_querystring("*", value)
112+
113+
def __rmul__(self, value: int):
114+
return self._get_querystring("*", value, reverse=True)
115+
116+
def __truediv__(self, value: int):
117+
return self._get_querystring("/", value)
118+
119+
def __rtruediv__(self, value: int):
120+
return self._get_querystring("/", value, reverse=True)
121+
122+
def __floordiv__(self, value: int):
123+
return self._get_querystring("/", value)
124+
125+
def __rfloordiv__(self, value: int):
126+
return self._get_querystring("/", value, reverse=True)
127+
85128

86129
###############################################################################
87130
# BigInt and SmallInt only exist on Postgres. SQLite treats them the same as

tests/table/test_update.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class TestUpdate(DBTestCase):
66
def test_update(self):
77
self.insert_rows()
88

9-
Band.update().values({Band.name: "Pythonistas3"}).where(
9+
Band.update({Band.name: "Pythonistas3"}).where(
1010
Band.name == "Pythonistas"
1111
).run_sync()
1212

@@ -18,3 +18,93 @@ def test_update(self):
1818
print(f"response = {response}")
1919

2020
self.assertEqual(response, [{"name": "Pythonistas3"}])
21+
22+
def test_update_values(self):
23+
self.insert_rows()
24+
25+
Band.update({Band.name: "Pythonistas3"}).where(
26+
Band.name == "Pythonistas"
27+
).run_sync()
28+
29+
response = (
30+
Band.select(Band.name)
31+
.where(Band.name == "Pythonistas3")
32+
.run_sync()
33+
)
34+
print(f"response = {response}")
35+
36+
self.assertEqual(response, [{"name": "Pythonistas3"}])
37+
38+
39+
class TestUpdateOperators(DBTestCase):
40+
def test_add(self):
41+
self.insert_row()
42+
43+
Band.update({Band.popularity: Band.popularity + 10}).run_sync()
44+
45+
response = Band.select(Band.popularity).first().run_sync()
46+
47+
self.assertEqual(response["popularity"], 1010)
48+
49+
def test_radd(self):
50+
self.insert_row()
51+
52+
Band.update({Band.popularity: 10 + Band.popularity}).run_sync()
53+
54+
response = Band.select(Band.popularity).first().run_sync()
55+
56+
self.assertEqual(response["popularity"], 1010)
57+
58+
def test_sub(self):
59+
self.insert_row()
60+
61+
Band.update({Band.popularity: Band.popularity - 10}).run_sync()
62+
63+
response = Band.select(Band.popularity).first().run_sync()
64+
65+
self.assertEqual(response["popularity"], 990)
66+
67+
def test_rsub(self):
68+
self.insert_row()
69+
70+
Band.update({Band.popularity: 1100 - Band.popularity}).run_sync()
71+
72+
response = Band.select(Band.popularity).first().run_sync()
73+
74+
self.assertEqual(response["popularity"], 100)
75+
76+
def test_mul(self):
77+
self.insert_row()
78+
79+
Band.update({Band.popularity: Band.popularity * 2}).run_sync()
80+
81+
response = Band.select(Band.popularity).first().run_sync()
82+
83+
self.assertEqual(response["popularity"], 2000)
84+
85+
def test_rmul(self):
86+
self.insert_row()
87+
88+
Band.update({Band.popularity: 2 * Band.popularity}).run_sync()
89+
90+
response = Band.select(Band.popularity).first().run_sync()
91+
92+
self.assertEqual(response["popularity"], 2000)
93+
94+
def test_div(self):
95+
self.insert_row()
96+
97+
Band.update({Band.popularity: Band.popularity / 10}).run_sync()
98+
99+
response = Band.select(Band.popularity).first().run_sync()
100+
101+
self.assertEqual(response["popularity"], 100)
102+
103+
def test_rdiv(self):
104+
self.insert_row()
105+
106+
Band.update({Band.popularity: 1000 / Band.popularity}).run_sync()
107+
108+
response = Band.select(Band.popularity).first().run_sync()
109+
110+
self.assertEqual(response["popularity"], 1)

0 commit comments

Comments
 (0)