Skip to content

Commit 4555388

Browse files
committed
using entrypoint script
1 parent 0065eb5 commit 4555388

File tree

3 files changed

+227
-220
lines changed

3 files changed

+227
-220
lines changed

piccolo/main.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import asyncio
2+
import datetime
3+
import importlib.util
4+
import os
5+
import sys
6+
import typing as t
7+
from types import ModuleType
8+
9+
import click
10+
11+
from piccolo.engine import PostgresEngine
12+
from piccolo.migrations.template import TEMPLATE
13+
from piccolo.migrations.table import Migration
14+
15+
16+
MIGRATIONS_FOLDER = os.path.join(os.getcwd(), 'migrations')
17+
MIGRATION_MODULES: t.Dict[str, ModuleType] = {}
18+
19+
20+
def _create_migrations_folder() -> bool:
21+
"""
22+
Creates the folder that migrations live in. Returns True/False depending
23+
on whether it was created or not.
24+
"""
25+
if os.path.exists(MIGRATIONS_FOLDER):
26+
return False
27+
else:
28+
os.mkdir(MIGRATIONS_FOLDER)
29+
for filename in ('__init__.py', 'config.py'):
30+
with open(os.path.join(MIGRATIONS_FOLDER, filename), 'w'):
31+
pass
32+
return True
33+
34+
35+
def _create_new_migration() -> None:
36+
"""
37+
Creates a new migration file on disk.
38+
"""
39+
_id = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
40+
path = os.path.join(MIGRATIONS_FOLDER, f'{_id}.py')
41+
with open(path, 'w') as f:
42+
f.write(TEMPLATE.format(migration_id=_id))
43+
44+
45+
###############################################################################
46+
47+
@click.command()
48+
def new():
49+
"""
50+
Creates a new file like migrations/0001_add_user_table.py
51+
"""
52+
print('Creating new migration ...')
53+
_create_migrations_folder()
54+
_create_new_migration()
55+
56+
57+
###############################################################################
58+
59+
def _create_migration_table() -> bool:
60+
"""
61+
Creates the migration table in the database. Returns True/False depending
62+
on whether it was created or not.
63+
"""
64+
if not Migration.table_exists().run_sync():
65+
Migration.create().run_sync()
66+
return True
67+
return False
68+
69+
70+
def _get_migrations_which_ran() -> t.List[str]:
71+
"""
72+
Returns the names of migrations which have already run, by inspecing the
73+
database.
74+
"""
75+
return [i['name'] for i in Migration.select('name').run_sync()]
76+
77+
78+
def _get_migration_modules() -> None:
79+
"""
80+
"""
81+
folder_contents = os.listdir(MIGRATIONS_FOLDER)
82+
excluded = ('__init__.py', 'config.py', '__pycache__')
83+
migration_names = [
84+
i.split('.py')[0] for i in folder_contents if i not in excluded
85+
]
86+
modules = [importlib.import_module(name) for name in migration_names]
87+
global MIGRATION_MODULES
88+
for m in modules:
89+
_id = getattr(m, 'ID', None)
90+
if _id:
91+
MIGRATION_MODULES[_id] = m
92+
93+
94+
def _get_migration_ids() -> t.List[str]:
95+
return list(MIGRATION_MODULES.keys())
96+
97+
98+
def _get_config() -> dict:
99+
"""
100+
A config file is required for the database credentials.
101+
"""
102+
sys.path.insert(0, MIGRATIONS_FOLDER)
103+
104+
config_file = os.path.join(MIGRATIONS_FOLDER, 'config.py')
105+
if not os.path.exists(config_file):
106+
raise Exception(f"Can't find config.py in {MIGRATIONS_FOLDER}")
107+
108+
config = importlib.import_module('config')
109+
110+
db = getattr(config, 'DB', None)
111+
if not db:
112+
raise Exception('config.py is missing a DB dictionary.')
113+
return db
114+
115+
116+
@click.command()
117+
def forwards():
118+
"""
119+
Runs any migrations which haven't been run yet, or up to a specific
120+
migration.
121+
"""
122+
print('Running migrations ...')
123+
sys.path.insert(0, os.getcwd())
124+
125+
Migration.Meta.db = PostgresEngine(_get_config())
126+
127+
_create_migration_table()
128+
129+
already_ran = _get_migrations_which_ran()
130+
print(f'Already ran = {already_ran}')
131+
132+
# TODO - stop using globals ...
133+
_get_migration_modules()
134+
ids = _get_migration_ids()
135+
print(f'Migration ids = {ids}')
136+
137+
for _id in (set(ids) - set(already_ran)):
138+
asyncio.run(
139+
MIGRATION_MODULES[_id].forwards()
140+
)
141+
142+
print(f'Ran {_id}')
143+
Migration.insert().add(
144+
Migration(name=_id)
145+
).run_sync()
146+
147+
148+
###############################################################################
149+
150+
@click.command()
151+
@click.argument('migration_name')
152+
def backwards(migration_name: str):
153+
"""
154+
Undo migrations up to a specific migrations.
155+
156+
- make sure the migration name is valid
157+
- work out which to undo, and in which order
158+
- ask for confirmation
159+
- apply the undo operations one by one
160+
"""
161+
# Get the list from disk ...
162+
sys.path.insert(0, os.getcwd())
163+
_get_config() # Just required for path manipulation - needs changing
164+
_get_migration_modules()
165+
166+
Migration.Meta.db = PostgresEngine(_get_config())
167+
168+
_create_migration_table()
169+
170+
migration_ids = _get_migration_ids()
171+
172+
if migration_name not in migration_ids:
173+
print(f'Unrecognized migration name - must be one of {migration_ids}')
174+
175+
_continue = input('About to undo the migrations - press y to continue.')
176+
if _continue == 'y':
177+
print('Undoing migrations')
178+
179+
_sorted = sorted(list(MIGRATION_MODULES.keys()))
180+
_sorted = _sorted[_sorted.index(migration_name):]
181+
_sorted.reverse()
182+
183+
already_ran = _get_migrations_which_ran()
184+
185+
for s in _sorted:
186+
if s not in already_ran:
187+
print(f"Migration {s} hasn't run yet, can't undo!")
188+
sys.exit(1)
189+
190+
print(migration_name)
191+
asyncio.run(
192+
MIGRATION_MODULES[s].backwards() # type: ignore
193+
)
194+
195+
Migration.delete().where(
196+
Migration.name == s
197+
).run_sync()
198+
else:
199+
print('Not proceeding.')
200+
201+
202+
###############################################################################
203+
204+
@click.group('migration')
205+
def cli():
206+
pass
207+
208+
209+
cli.add_command(new)
210+
cli.add_command(forwards)
211+
cli.add_command(backwards)
212+
213+
214+
def main():
215+
cli()
216+
217+
218+
###############################################################################
219+
220+
221+
if __name__ == '__main__':
222+
main()

0 commit comments

Comments
 (0)