Skip to content

Commit c3edb1a

Browse files
committed
added support for aliases (SELECT foo as bar)
1 parent 7183737 commit c3edb1a

File tree

6 files changed

+107
-11
lines changed

6 files changed

+107
-11
lines changed

piccolo/columns/base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from abc import ABCMeta, abstractmethod
3+
import copy
34
from dataclasses import dataclass, field
45
import datetime
56
import decimal
@@ -226,6 +227,8 @@ def __init__(
226227
required=required,
227228
)
228229

230+
self.alias: t.Optional[str] = None
231+
229232
def _validate_default(
230233
self,
231234
default: t.Any,
@@ -306,6 +309,20 @@ def __ne__(self, value) -> Where: # type: ignore
306309
def __hash__(self):
307310
return hash(self._meta.name)
308311

312+
def as_alias(self, name: str) -> Column:
313+
"""
314+
Allows column names to be changed in the result of a select.
315+
316+
For example:
317+
318+
>>> await Band.select(Band.name.as_alias('title')).run()
319+
{'title': 'Pythonistas'}
320+
321+
"""
322+
column = copy.deepcopy(self)
323+
column.alias = name
324+
return column
325+
309326
def get_default_value(self) -> t.Any:
310327
"""
311328
If the column has a default attribute, return it. If it's callable,
@@ -323,7 +340,14 @@ def get_select_string(self, engine_type: str, just_alias=False) -> str:
323340
"""
324341
How to refer to this column in a SQL query.
325342
"""
326-
return self._meta.get_full_name(just_alias=just_alias)
343+
if self.alias is None:
344+
return self._meta.get_full_name(just_alias=just_alias)
345+
else:
346+
original_name = self._meta.get_full_name(just_alias=True)
347+
return f"{original_name} AS {self.alias}"
348+
349+
def get_where_string(self, engine_type: str) -> str:
350+
return self.get_select_string(engine_type=engine_type, just_alias=True)
327351

328352
def get_sql_value(self, value: t.Any) -> t.Any:
329353
"""

piccolo/columns/column_types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,10 +1094,11 @@ def arrow(self, key: str) -> JSON:
10941094
return self
10951095

10961096
def get_select_string(self, engine_type: str, just_alias=False) -> str:
1097-
select_string = super().get_select_string(
1098-
engine_type=engine_type, just_alias=just_alias
1099-
)
1097+
select_string = self._meta.get_full_name(just_alias=just_alias)
11001098
if self.json_operator is None:
11011099
return select_string
11021100
else:
1103-
return f"{select_string} {self.json_operator} AS {self._meta.name}"
1101+
if self.alias is None:
1102+
return f"{select_string} {self.json_operator}"
1103+
else:
1104+
return f"{select_string} {self.json_operator} AS {self.alias}"

piccolo/columns/combination.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Where(CombinableMixin):
6060

6161
def __init__(
6262
self,
63-
column: "Column",
63+
column: Column,
6464
value: t.Any = UNDEFINED,
6565
values: t.Union[Iterable, Undefined] = UNDEFINED,
6666
operator: t.Type[ComparisonOperator] = ComparisonOperator,
@@ -91,7 +91,9 @@ def querystring(self) -> QueryString:
9191
args.append(self.values_querystring)
9292

9393
template = self.operator.template.format(
94-
name=self.column._meta.get_full_name(just_alias=True),
94+
name=self.column.get_where_string(
95+
engine_type=self.column._meta.engine_type, just_alias=True
96+
),
9597
value="{}",
9698
values="{}",
9799
)

piccolo/query/methods/select.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def querystrings(self) -> t.Sequence[QueryString]:
248248

249249
#######################################################################
250250

251-
select = "SELECT DISTINCT" if self.distinct else "SELECT"
251+
select = (
252+
"SELECT DISTINCT" if self.distinct_delegate._distinct else "SELECT"
253+
)
252254
query = f"{select} {columns_str} FROM {self.table._meta.tablename}"
253255

254256
for join in joins:

tests/columns/test_jsonb.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class MyTable(Table):
1111

1212

1313
@postgres_only
14-
class TestJSON(TestCase):
14+
class TestJSONB(TestCase):
1515
def setUp(self):
1616
MyTable.create_table().run_sync()
1717

@@ -32,4 +32,29 @@ def test_arrow(self):
3232
"""
3333
MyTable(json='{"a": 1}').save().run_sync()
3434
row = MyTable.select(MyTable.json.arrow("a")).first().run_sync()
35-
self.assertEqual(row["json"], "1")
35+
self.assertEqual(row["?column?"], "1")
36+
37+
def test_arrow_as_alias(self):
38+
"""
39+
Test using the arrow function to retrieve a subset of the JSON.
40+
"""
41+
MyTable(json='{"a": 1}').save().run_sync()
42+
row = (
43+
MyTable.select(MyTable.json.arrow("a").as_alias("a"))
44+
.first()
45+
.run_sync()
46+
)
47+
self.assertEqual(row["a"], "1")
48+
49+
def test_arrow_where(self):
50+
"""
51+
Make sure the arrow function can be used within a WHERE clause.
52+
"""
53+
MyTable(json='{"a": 1}').save().run_sync()
54+
self.assertEqual(
55+
MyTable.count().where(MyTable.json.arrow("a") == "1").run_sync(), 1
56+
)
57+
58+
self.assertEqual(
59+
MyTable.count().where(MyTable.json.arrow("a") == "2").run_sync(), 0
60+
)

tests/table/test_select.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ def test_distinct(self):
285285
self.insert_rows()
286286
self.insert_rows()
287287

288+
response = (
289+
Band.select(Band.name).where(Band.name == "Pythonistas").run_sync()
290+
)
291+
292+
self.assertTrue(
293+
response == [{"name": "Pythonistas"}, {"name": "Pythonistas"}]
294+
)
295+
288296
response = (
289297
Band.select(Band.name)
290298
.where(Band.name == "Pythonistas")
@@ -410,9 +418,43 @@ def test_call_chain(self):
410418
"""
411419
Make sure the call chain lengths are the correct size.
412420
"""
413-
# self.assertEqual(len(Concert.band_1.name._meta.call_chain), 1)
421+
self.assertEqual(len(Concert.band_1.name._meta.call_chain), 1)
414422
self.assertEqual(len(Concert.band_1.manager.name._meta.call_chain), 2)
415423

424+
def test_as_alias(self):
425+
"""
426+
Make sure we can specify aliases for the columns.
427+
"""
428+
self.insert_row()
429+
response = Band.select(Band.name.as_alias("title")).run_sync()
430+
self.assertEqual(response, [{"title": "Pythonistas"}])
431+
432+
def test_as_alias_with_join(self):
433+
"""
434+
Make sure we can specify aliases for the column, when performing a
435+
join.
436+
"""
437+
self.insert_row()
438+
response = Band.select(
439+
Band.manager.name.as_alias("manager_name")
440+
).run_sync()
441+
self.assertEqual(response, [{"manager_name": "Guido"}])
442+
443+
def test_as_alias_with_where_clause(self):
444+
"""
445+
Make sure we can specify aliases for the column, when the column is
446+
also being used in a where clause.
447+
"""
448+
self.insert_row()
449+
response = (
450+
Band.select(Band.name, Band.manager.name.as_alias("manager_name"))
451+
.where(Band.manager.name == "Guido")
452+
.run_sync()
453+
)
454+
self.assertEqual(
455+
response, [{"name": "Pythonistas", "manager_name": "Guido"}]
456+
)
457+
416458

417459
class TestSelectSecret(TestCase):
418460
def setUp(self):

0 commit comments

Comments
 (0)