diff --git a/.github/ISSUE_TEMPLATES/features.yml b/.github/ISSUE_TEMPLATES/features.yml new file mode 100644 index 000000000..143c336c3 --- /dev/null +++ b/.github/ISSUE_TEMPLATES/features.yml @@ -0,0 +1,22 @@ +name: Features 💡 +description: Suggest a new feature +title: "[Feature]: Describe your feature idea here" +labels: ["goal:addition"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this suggestion form! + - type: input + attributes: + label: Suggestion + description: Can you please elaborate on your suggestion? + placeholder: I would like to see... + validations: + required: true + - type: textarea + id: multimedia + attributes: + label: Relevant media + description: You can add any media which helps explain your idea. + placeholder: Please upload your multimedia (screenshots and video) here. diff --git a/.github/ISSUE_TEMPLATES/issues.yml b/.github/ISSUE_TEMPLATES/issues.yml new file mode 100644 index 000000000..81e8145b8 --- /dev/null +++ b/.github/ISSUE_TEMPLATES/issues.yml @@ -0,0 +1,22 @@ +name: Bug 🐛 +description: Report a bug +title: "[Issue]: Describe the bug here" +labels: ["goal:fix", "priority:medium"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + - type: input + attributes: + label: Bug description + description: Can you please describe the bug in more detail? + placeholder: Tell us what you found. + validations: + required: true + - type: textarea + id: multimedia + attributes: + label: Relevant log output + description: Please paste any relevant log output. + placeholder: Paste your log output here. diff --git a/.github/ISSUE_TEMPLATES/support.yml b/.github/ISSUE_TEMPLATES/support.yml new file mode 100644 index 000000000..4cecf39b5 --- /dev/null +++ b/.github/ISSUE_TEMPLATES/support.yml @@ -0,0 +1,16 @@ +name: Support 💁‍♂️ +description: Raise a support ticket +title: "[Support]: Describe your issue / support requirement here" +labels: [support] +body: + - type: markdown + attributes: + value: | + Thanks for reaching out! + - type: input + attributes: + label: Support ticket information + description: Please elaborate on your support requirement. + placeholder: I need help with...
+ validations: + required: true diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ee4569f63..0f14f4210 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -11,10 +11,10 @@ jobs: runs-on: "ubuntu-latest" steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v1" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v5" with: - python-version: 3.7 + python-version: 3.13 - name: "Install dependencies" run: "pip install -r requirements/dev-requirements.txt" - name: "Publish to PyPI" diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 691aeff55..2f6655c2d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,22 +1,27 @@ name: Test Suite on: - push: - branches: ["master"] - pull_request: - branches: ["master"] + push: + branches: ["master", "v1"] + paths-ignore: + - "docs/**" + pull_request: + branches: ["master", "v1"] + paths-ignore: + - "docs/**" jobs: linters: runs-on: ubuntu-latest + timeout-minutes: 60 strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -28,12 +33,60 @@ jobs: - name: Lint run: ./scripts/lint.sh + integration: + runs-on: ubuntu-latest + timeout-minutes: 60 + strategy: + matrix: + # These tests are slow, so we only run on the latest Python + # version. + python-version: ["3.13"] + postgres-version: [17] + services: + postgres: + image: postgres:${{ matrix.postgres-version }} + env: + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/postgres.txt + - name: Setup postgres + run: | + export PGPASSWORD=postgres + psql -h localhost -c 'CREATE DATABASE piccolo;' -U postgres + psql -h localhost -c "CREATE USER piccolo PASSWORD 'piccolo';" -U postgres + psql -h localhost -c "GRANT ALL PRIVILEGES ON DATABASE piccolo TO piccolo;" -U postgres + psql -h localhost -c "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";" -d piccolo -U postgres + - name: Run integration tests + run: ./scripts/test-integration.sh + env: + PG_HOST: localhost + PG_DATABASE: piccolo + PG_PASSWORD: postgres + postgres: runs-on: ubuntu-latest + timeout-minutes: 60 strategy: matrix: - python-version: [3.7, 3.8, 3.9] - postgres-version: [9.6, 10, 11, 12, 13] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + postgres-version: [13, 14, 15, 16, 17, 18] # Service containers to run with `container-job` services: @@ -54,9 +107,9 @@ jobs: - 5432:5432 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -81,18 +134,53 @@ jobs: PG_PASSWORD: postgres - name: 
Upload coverage uses: codecov/codecov-action@v1 - if: matrix.python-version == '3.7' + if: matrix.python-version == '3.13' + + cockroach: + runs-on: ubuntu-latest + timeout-minutes: 60 + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + cockroachdb-version: ["v24.1.0"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/postgres.txt + - name: Setup CockroachDB + run: | + wget -qO- https://binaries.cockroachdb.com/cockroach-${{ matrix.cockroachdb-version }}.linux-amd64.tgz | tar xz + ./cockroach-${{ matrix.cockroachdb-version }}.linux-amd64/cockroach start-single-node --insecure --background + ./cockroach-${{ matrix.cockroachdb-version }}.linux-amd64/cockroach sql --insecure -e 'create database piccolo;' + + - name: Test with pytest, CockroachDB + run: ./scripts/test-cockroach.sh + env: + PG_HOST: localhost + PG_DATABASE: piccolo + - name: Upload coverage + uses: codecov/codecov-action@v1 + if: matrix.python-version == '3.13' sqlite: runs-on: ubuntu-latest + timeout-minutes: 60 strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -105,4 +193,4 @@ jobs: run: ./scripts/test-sqlite.sh - name: Upload coverage uses: codecov/codecov-action@v1 - if: matrix.python-version == '3.7' + if: matrix.python-version == '3.13' diff --git a/.gitignore b/.gitignore index 5356d08b8..b318c8ebb 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ build/ .doctrees/ .vscode/ piccolo.egg-info/ -build/ +.idea/ dist/ piccolo.sqlite _build/ @@ -18,3 +18,9 @@ htmlcov/ prof/ .env/ .venv/ +result.json + +# CockroachDB +cockroach-data/ +heap_profiler/ +goroutine_dump/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..2e91ea548 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,21 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +sphinx: + configuration: docs/src/conf.py + +formats: + - pdf + - epub + +python: + install: + - requirements: requirements/readthedocs-requirements.txt diff --git a/CHANGES b/CHANGES deleted file mode 100644 index dbb0d182a..000000000 --- a/CHANGES +++ /dev/null @@ -1,690 +0,0 @@ -Changes -======= - -0.33.0 ------- -Fix for auto migrations when using custom primary keys (thanks to @adriangb and -@aminalaee for investigating this issue). - -0.32.0 ------- -Migrations can now have a description, which is shown when using -``piccolo migrations check``. This makes migrations easier to identify (thanks -to @davidolrik for the idea). - -0.31.0 ------- -Added an ``all_columns`` method, to make it easier to retrieve all related -columns when doing a join. For example: - -.. 
code-block:: python - - await Band.select(Band.name, *Band.manager.all_columns()).first().run() - -Changed the instructions for installing additional dependencies, so they're -wrapped in quotes, to make sure it works on ZSH (i.e. -``pip install 'piccolo[postgres]'`` instead of -``pip install piccolo[postgres]``). - -0.30.0 ------- -The database drivers are now installed separately. For example: -``pip install piccolo[postgres]`` (courtesy @aminalaee). - -For some users this might be a **breaking change** - please make sure that for -existing Piccolo projects, you have either ``asyncpg``, or -``piccolo[postgres]`` in your ``requirements.txt`` file. - -0.29.0 ------- -The user can now specify the primary key column (courtesy @aminalaee). For -example: - -.. code-block:: python - - class RecordingStudio(Table): - pk = UUID(primary_key=True) - -The BlackSheep template generated by `piccolo asgi new` now supports mounting -of the Piccolo Admin (courtesy @sinisaos). - -0.28.0 ------- -Added aggregations functions, such as ``Sum``, ``Min``, ``Max`` and ``Avg``, -for use in select queries (courtesy @sinisaos). - -0.27.0 ------- -Added uvloop as an optional dependency, installed via `pip install piccolo[uvloop]` -(courtesy @aminalaee). uvloop is a faster implementation of the asyncio event -loop found in Python's standard library. When uvloop is installed, Piccolo will -use it to increase the performance of the Piccolo CLI, and web servers such as -Uvicorn will use it to increase the performance of your ASGI app. - -0.26.0 ------- -Added ``eq`` and ``ne`` methods to the ``Boolean`` column, which can be used -if linters complain about using ``SomeTable.some_column == True``. - -0.25.0 ------- - * Changed the migration IDs, so the timestamp now includes microseconds. This - is to make clashing migration IDs much less likely. - * Added a lot of end-to-end tests for migrations, which revealed some bugs - in ``Column`` defaults. - -0.24.1 ------- -A bug fix for migrations. See `issue 123 `_ -for more information. - -0.24.0 ------- -Lots of improvements to ``JSON`` and ``JSONB`` columns. Piccolo will now -automatically convert between Python types and JSON strings. For example, with -this schema: - -.. code-block:: python - - class RecordingStudio(Table): - name = Varchar() - facilities = JSON() - -We can now do the following: - -.. code-block:: python - - RecordingStudio( - name="Abbey Road", - facilities={'mixing_desk': True} # Will automatically be converted to a JSON string - ).save().run_sync() - -Similarly, when fetching data from a JSON column, Piccolo can now automatically -deserialise it. - -.. code-block:: python - - >>> RecordingStudio.select().output(load_json=True).run_sync() - [{'id': 1, 'name': 'Abbey Road', 'facilities': {'mixing_desk': True}] - - >>> studio = RecordingStudio.objects().first().output(load_json=True).run_sync() - >>> studio.facilities - {'mixing_desk': True} - -0.23.0 ------- -Added the ``create_table_class`` function, which can be used to create -``Table`` subclasses at runtime. This was required to fix an existing bug, -which was effecting migrations (see `issue 111 `_ -for more details). - -0.22.0 ------- - * An error is now raised if a user tries to create a Piccolo app using - ``piccolo app new`` with the same name as a builtin Python module, as it - will cause strange bugs. - * Fixing a strange bug where using an expression such as - ``Concert.band_1.manager.id`` in a query would cause an error. 
It only - happened if multiple joins were involved, and the last column in the chain - was ``id``. - * ``where`` clauses can now accept ``Table`` instances. For example: - ``await Band.select().where(Band.manager == some_manager).run()``, instead - of having to explicity reference the ``id``. - -0.21.2 ------- -Fixing a bug with serialising ``Enum`` instances in migrations. For example: -``Varchar(default=Colour.red)``. - -0.21.1 ------- -Fix missing imports in FastAPI and Starlette app templates. - -0.21.0 ------- - * Added a ``freeze`` method to ``Query``. - * Added BlackSheep as an option to ``piccolo asgi new``. - -0.20.0 ------- -Added ``choices`` option to ``Column``. - -0.19.1 ------- - * Added ``piccolo user change_permissions`` command. - * Added aliases for CLI commands. - -0.19.0 ------- -Changes to the ``BaseUser`` table - added a ``superuser``, and ``last_login`` -column. These are required for upgrades to Piccolo Admin. - -If you're using migrations, then running ``piccolo migrations forwards all`` -should add these new columns for you. - -If not using migrations, the ``BaseUser`` table can be upgraded using the -following DDL statements: - -.. code-block:: sql - - ALTER TABLE piccolo_user ADD COLUMN "superuser" BOOLEAN NOT NULL DEFAULT false - ALTER TABLE piccolo_user ADD COLUMN "last_login" TIMESTAMP DEFAULT null - -0.18.4 ------- - * Fixed a bug when multiple tables inherit from the same mixin (thanks to - @brnosouza). - * Added a ``log_queries`` option to ``PostgresEngine``, which is useful during - debugging. - * Added the `inflection` library for converting ``Table`` class names to - database table names. Previously, a class called ``TableA`` would wrongly - have a table called ``table`` instead of ``table_a``. - * Fixed a bug with ``SerialisedBuiltin.__hash__`` not returning a number, - which could break migrations (thanks to @sinisaos). - -0.18.3 ------- -Improved ``Array`` column serialisation - needed to fix auto migrations. - -0.18.2 ------- -Added support for filtering ``Array`` columns. - -0.18.1 ------- -Add the ``Array`` column type as a top level import in ``piccolo.columns``. - -0.18.0 ------- - * Refactored ``forwards`` and ``backwards`` commands for migrations, to make - them easier to run programatically. - * Added a simple ``Array`` column type. - * ``table_finder`` now works if just a string is passed in, instead of having - to pass in an array of strings. - -0.17.5 ------- -Catching database connection exceptions when starting the default ASGI app -created with ``piccolo asgi new`` - these errors exist if the Postgres -database hasn't been created yet. - -0.17.4 ------- -Added a ``help_text`` option to the ``Table`` metaclass. This is used in -Piccolo Admin to show tooltips. - -0.17.3 ------- -Added a ``help_text`` option to the ``Column`` constructor. This is used in -Piccolo Admin to show tooltips. - -0.17.2 ------- - * Exposing ``index_type`` in the ``Column`` constructor. - * Fixing a typo with ``start_connection_pool` and ``close_connection_pool`` - - thanks to paolodina for finding this. - * Fixing a typo in the ``PostgresEngine`` docs - courtesy of paolodina. - -0.17.1 ------- -Fixing a bug with ``SchemaSnapshot`` if column types were changed in migrations -- the snapshot didn't reflect the changes. - -0.17.0 ------- - * Migrations now directly import ``Column`` classes - this allows users to - create custom ``Column`` subclasses. Migrations previously only worked with - the builtin column types. 
- * Migrations now detect if the column type has changed, and will try and - convert it automatically. - -0.16.5 ------- -The Postgres extensions that ``PostgresEngine`` tries to enable at startup -can now be configured. - -0.16.4 ------- - * Fixed a bug with ``MyTable.column != None`` - * Added ``is_null`` and ``is_not_null`` methods, to avoid linting issues when - comparing with None. - -0.16.3 ------- - * Added ``WhereRaw``, so raw SQL can be used in where clauses. - * ``piccolo shell run`` now uses syntax highlighting - courtesy of Fingel. - -0.16.2 ------- -Reordering the dependencies in requirements.txt when using ``piccolo asgi new`` -as the latest FastAPI and Starlette versions are incompatible. - -0.16.1 ------- -Added ``Timestamptz`` column type, for storing datetimes which are timezone -aware. - -0.16.0 ------- - * Fixed a bug with creating a ``ForeignKey`` column with ``references="self"`` - in auto migrations. - * Changed migration file naming, so there are no characters in there which - are unsupported on Windows. - -0.15.1 ------- -Changing the status code when creating a migration, and no changes were -detected. It now returns a status code of 0, so it doesn't fail build scripts. - -0.15.0 ------- -Added ``Bytea`` / ``Blob`` column type. - -0.14.13 -------- -Fixing a bug with migrations which drop column defaults. - -0.14.12 -------- - * Fixing a bug where re-running ``Table.create(if_not_exists=True)`` would - fail if it contained columns with indexes. - * Raising a ``ValueError`` if a relative path is provided to ``ForeignKey`` - ``references``. For example, ``.tables.Manager``. The paths must be absolute - for now. - -0.14.11 -------- -Fixing a bug with ``Boolean`` column defaults, caused by the ``Table`` -metaclass not being explicit enough when checking falsy values. - -0.14.10 -------- - * The ``ForeignKey`` ``references`` argument can now be specified using a - string, or a ``LazyTableReference`` instance, rather than just a ``Table`` - subclass. This allows a ``Table`` to be specified which is in a Piccolo app, - or Python module. The ``Table`` is only loaded after imports have completed, - which prevents circular import issues. - * Faster column copying, which is important when specifying joins, e.g. - ``await Band.select(Band.manager.name).run()``. - * Fixed a bug with migrations and foreign key contraints. - -0.14.9 ------- -Modified the exit codes for the ``forwards`` and ``backwards`` commands when no -migrations are left to run / reverse. Otherwise build scripts may fail. - -0.14.8 ------- - * Improved the method signature of the ``output`` query clause (explicitly - added args, instead of using ``**kwargs``). - * Fixed a bug where ``output(as_list=True)`` would fail if no rows were found. - * Made ``piccolo migrations forwards`` command output more legible. - * Improved renamed table detection in migrations. - * Added the ``piccolo migrations clean`` command for removing orphaned rows - from the migrations table. - * Fixed a bug where ``get_migration_managers`` wasn't inclusive. - * Raising a ``ValueError`` if ``is_in`` or ``not_in`` query clauses are passed - an empty list. - * Changed the migration commands to be top level async. - * Combined ``print`` and ``sys.exit`` statements. - -0.14.7 ------- - * Added missing type annotation for ``run_sync``. - * Updating type annotations for column default values - allowing callables. - * Replaced instances of ``asyncio.run`` with ``run_sync``. - * Tidied up aiosqlite imports. 
- -0.14.6 ------- - * Added JSON and JSONB column types, and the arrow function for JSONB. - * Fixed a bug with the distinct clause. - * Added ``as_alias``, so select queries can override column names in the - response (i.e. SELECT foo AS bar from baz). - * Refactored JSON encoding into a separate utils file. - -0.14.5 ------- - * Removed old iPython version recommendation in the ``piccolo shell run`` and - ``piccolo playground run``, and enabled top level await. - * Fixing outstanding mypy warnings. - * Added optional requirements for the playground to setup.py - -0.14.4 ------- - * Added ``piccolo sql_shell run`` command, which launches the psql or sqlite3 - shell, using the connection parameters defined in ``piccolo_conf.py``. - This is convenient when you want to run raw SQL on your database. - * ``run_sync`` now handles more edge cases, for example if there's already - an event loop in the current thread. - * Removed asgiref dependency. - -0.14.3 ------- - * Queries can be directly awaited - ``await MyTable.select()``, as an - alternative to using the run method ``await MyTable.select().run()``. - * The ``piccolo asgi new`` command now accepts a ``name`` argument, which is - used to populate the default database name within the template. - -0.14.2 ------- - * Centralised code for importing Piccolo apps and tables - laying the - foundation for fixtures. - * Made orjson an optional dependency, installable using - ``pip install piccolo[orjson]``. - * Improved version number parsing in Postgres. - -0.14.1 ------- -Fixing a bug with dropping tables in auto migrations. - -0.14.0 ------- -Added ``Interval`` column type. - -0.13.5 ------- - * Added ``allowed_hosts`` to ``create_admin`` in ASGI template. - * Fixing bug with default ``root`` argument in some piccolo commands. - -0.13.4 ------- - * Fixed bug with ``SchemaSnapshot`` when dropping columns. - * Added custom ``__repr__`` method to ``Table``. - -0.13.3 ------- -Added ``piccolo shell run`` command for running adhoc queries using Piccolo. - -0.13.2 ------- - * Fixing bug with auto migrations when dropping columns. - * Added a ``root`` argument to ``piccolo asgi new``, ``piccolo app new`` and - ``piccolo project new`` commands, to override where the files are placed. - -0.13.1 ------- -Added support for ``group_by`` and ``Count`` for aggregate queries. - -0.13.0 ------- -Added `required` argument to ``Column``. This allows the user to indicate which -fields must be provided by the user. Other tools can use this value when -generating forms and serialisers. - -0.12.6 ------- - * Fixing a typo in ``TimestampCustom`` arguments. - * Fixing bug in ``TimestampCustom`` SQL representation. - * Added more extensive deserialisation for migrations. - -0.12.5 ------- - * Improved ``PostgresEngine`` docstring. - * Resolving rename migrations before adding columns. - * Fixed bug serialising ``TimestampCustom``. - * Fixed bug with altering column defaults to be non-static values. - * Removed ``response_handler`` from ``Alter`` query. - -0.12.4 ------- -Using orjson for JSON serialisation when using the ``output(as_json=True)`` -clause. It supports more Python types than ujson. - -0.12.3 ------- -Improved ``piccolo user create`` command - defaults the username to the current -system user. - -0.12.2 ------- -Fixing bug when sorting ``extra_definitions`` in auto migrations. - -0.12.1 ------- - * Fixed typos. - * Bumped requirements. - -0.12.0 ------- - * Added ``Date`` and ``Time`` columns. - * Improved support for column default values. 
- * Auto migrations can now serialise more Python types. - * Added ``Table.indexes`` method for listing table indexes. - * Auto migrations can handle adding / removing indexes. - * Improved ASGI template for FastAPI. - -0.11.8 ------- -ASGI template fix. - -0.11.7 ------- - * Improved ``UUID`` columns in SQLite - prepending 'uuid:' to the stored value - to make the type more explicit for the engine. - * Removed SQLite as an option for ``piccolo asgi new`` until auto migrations - are supported. - -0.11.6 ------- -Added support for FastAPI to ``piccolo asgi new``. - -0.11.5 ------- -Fixed bug in ``BaseMigrationManager.get_migration_modules`` - wasn't -excluding non-Python files well enough. - -0.11.4 ------- - * Stopped ``piccolo migrations new`` from creating a config.py file - was - legacy. - * Added a README file to the `piccolo_migrations` folder in the ASGI template. - -0.11.3 ------- -Fixed `__pycache__` bug when using ``piccolo asgi new``. - -0.11.2 ------- - * Showing a warning if trying auto migrations with SQLite. - * Added a command for creating a new ASGI app - ``piccolo asgi new``. - * Added a meta app for printing out the Piccolo version - - ``piccolo meta version``. - * Added example queries to the playground. - -0.11.1 ------- - * Added ``table_finder``, for use in ``AppConfig``. - * Added support for concatenating strings using an update query. - * Added more tables to the playground, with more column types. - * Improved consistency between SQLite and Postgres with ``UUID`` columns, - ``Integer`` columns, and ``exists`` queries. - -0.11.0 ------- -Added ``Numeric`` and ``Real`` column types. - -0.10.8 ------- -Fixing a bug where Postgres versions without a patch number couldn't be parsed. - -0.10.7 ------- -Improving release script. - -0.10.6 ------- -Sorting out packaging issue - old files were appearing in release. - -0.10.5 ------- -Auto migrations can now run backwards. - -0.10.4 ------- -Fixing some typos with ``Table`` imports. Showing a traceback when piccolo_conf -can't be found by ``engine_finder``. - -0.10.3 ------- -Adding missing jinja templates to setup.py. - -0.10.2 ------- -Fixing a bug when using ``piccolo project new`` in a new project. - -0.10.1 ------- -Fixing bug with enum default values. - -0.10.0 ------- -Using targ for the CLI. Refactored some core code into apps. - -0.9.3 ------ -Suppressing exceptions when trying to find the Postgres version, to avoid -an ``ImportError`` when importing `piccolo_conf.py`. - -0.9.2 ------ -``.first()`` bug fix. - -0.9.1 ------ -Auto migration fixes, and ``.first()`` method now returns None if no match is -found. - -0.9.0 ------ -Added support for auto migrations. - -0.8.3 ------ -Can use operators in update queries, and fixing 'new' migration command. - -0.8.2 ------ -Fixing release issue. - -0.8.1 ------ -Improved transaction support - can now use a context manager. Added ``Secret``, -``BigInt`` and ``SmallInt`` column types. Foreign keys can now reference the -parent table. - -0.8.0 ------ -Fixing bug when joining across several tables. Can pass values directly into -the ``Table.update`` method. Added ``if_not_exists`` option when creating a -table. - -0.7.7 ------ -Column sequencing matches the definition order. - -0.7.6 ------ -Supporting `ON DELETE` and `ON UPDATE` for foreign keys. Recording reverse -foreign key relationships. - -0.7.5 ------ -Made ``response_handler`` async. Made it easier to rename columns. - -0.7.4 ------ -Bug fixes and dependency updates. - -0.7.3 ------ -Adding missing `__int__.py` file. 
- -0.7.2 ------ -Changed migration import paths. - -0.7.1 ------ -Added ``remove_db_file`` method to ``SQLiteEngine`` - makes testing easier. - -0.7.0 ------ -Renamed ``create`` to ``create_table``, and can register commands via -`piccolo_conf`. - -0.6.1 ------ -Adding missing `__init__.py` files. - -0.6.0 ------ -Moved ``BaseUser``. Migration refactor. - -0.5.2 ------ -Moved drop table under ``Alter`` - to help prevent accidental drops. - -0.5.1 ------ -Added ``batch`` support. - -0.5.0 ------ -Refactored the ``Table`` Metaclass - much simpler now. Scoped more of the -attributes on ``Column`` to avoid name clashes. Added ``engine_finder`` to make -database configuration easier. - -0.4.1 ------ -SQLite is now returning datetime objects for timestamp fields. - -0.4.0 ------ -Refactored to improve code completion, along with bug fixes. - -0.3.7 ------ -Allowing ``Update`` queries in SQLite - -0.3.6 ------ -Falling back to `LIKE` instead of `ILIKE` for SQLite - -0.3.5 ------ -Renamed ``User`` to ``BaseUser``. - -0.3.4 ------ -Added ``ilike``. - -0.3.3 ------ -Added value types to columns. - -0.3.2 ------ -Default values infer the engine type. - -0.3.1 ------ -Update click version. - -0.3 ---- -Tweaked API to support more auto completion. Join support in where clause. -Basic SQLite support - mostly for playground. - -0.2 ---- -Using ``QueryString`` internally to represent queries, instead of raw strings, -to harden against SQL injection. - -0.1.2 ------ -Allowing joins across multiple tables. - -0.1.1 ------ -Added playground. diff --git a/CHANGES.rst b/CHANGES.rst new file mode 100644 index 000000000..ea841cac1 --- /dev/null +++ b/CHANGES.rst @@ -0,0 +1,4501 @@ +Changes +======= + +1.30.0 +------ + +Added support for Python 3.14 (thanks to @sinisaos for this). + +------------------------------------------------------------------------------- + +1.29.0 +------ + +* Fixed a bug with adding / subtracting ``Integer`` columns from one another + in queries (thanks to @ryanvarley for this). +* Updated the ASGI templates, and BlackSheep dependencies (thanks to @sinisaos + for this). +* Fixed a bug where decimal values generated by ``ModelBuilder`` could be too + large. +* Added an example ``M2M`` relationship in the playground to make learning + ``M2M`` easier (thanks to @sinisaos for this). +* Added documentation for ``MigrationManager.get_table_from_snapshot``, which + is a way of getting a ``Table`` from the migration history - useful when + running data migrations (thanks to @sinisaos for this). +* Columns with the ``secret=True`` argument are now added to + ``Table._meta.secret_columns`` (thanks to @sinisaos for this). +* Added documentation for the ``migration`` table. +* Tidied up Pydantic tests (thanks to @sinisaos for this). + +------------------------------------------------------------------------------- + +1.28.0 +------ + +Playground improvements +~~~~~~~~~~~~~~~~~~~~~~~ + +* Added an ``Array`` column to the playground (``Album.awards``), for easier + experimentation with array columns. +* CockroachDB is now supported in the playground (thanks to @sinisaos for this). + + .. code-block:: bash + + piccolo playground run --engine=cockroach + +Functions +~~~~~~~~~ + +Added lots of useful array functions (thanks to @sinisaos for this). + +Here's an example, where we can easily fix a typo in an array using ``replace``: + +.. code-block:: python + + >>> await Album.update({ + ... Album.awards: Album.awards.replace('Grammy Award 2021', 'Grammy Award 2022') + ...
}, force=True) + +The documentation for functions has also been improved (e.g. how to create a +custom function). + +The ``Cast`` function is now more flexible. + +``Array`` concatenation +~~~~~~~~~~~~~~~~~~~~~~~~ + +Values can be prepended: + +.. code-block:: python + + >>> await Album.update({ + ... Album.awards: ['Grammy Award 2020'] + Album.awards + ... }, force=True) + +And multiple arrays can be concatenated in one go: + +.. code-block:: python + + >>> await Album.update({ + ... Album.awards: ['Grammy Award 2020'] + Album.awards + ['Grammy Award 2025'] + ... }, force=True) + +``is_in`` and ``not_in`` sub queries +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can now use sub queries within ``is_in`` and ``not_in``. Thanks to +@sinisaos for this. + +.. code-block:: python + + >>> await Band.select().where( + ... Band.id.is_in( + ... Concert.select(Concert.band_1).where( + ... Concert.starts >= datetime.datetime(year=2025, month=1, day=1) + ... ) + ... ) + ... ) + +Other improvements +~~~~~~~~~~~~~~~~~~ + +* Auto convert a default value of ``0`` to ``0.0`` in ``Float`` columns. +* Modernised the type hints throughout the codebase (e.g. using ``list`` + instead of ``typing.List``). Thanks to @sinisaos for this. +* Fixed a bug with auto migrations, where the ``Array`` base column class + wasn't being imported. +* Improved M2M query performance by using sub selects (thanks to @sinisaos for + this). + +------------------------------------------------------------------------------- + +1.27.1 +------ + +Improved the type annotations in ``ColumnKwargs`` - made some optional. Thanks +to @stronk7 and @sinisaos for their help with this. + +------------------------------------------------------------------------------- + +1.27.0 +------ + +Improved auto completion / typo detection for column arguments. + +For example: + +.. code-block:: python + + class Band(Table): + name = Varchar(nul=True) # linters will now warn that nul is a typo (should be null) + +Thanks to @sinisaos for this. + +------------------------------------------------------------------------------- + +1.26.1 +------ + +Updated the BlackSheep ASGI template - thanks to @sinisaos for this. + +Fixed a bug with auto migrations when a ``ForeignKey`` specifies +``target_column`` - multiple primary key columns were added to the migration +file. Thanks to @waldner for reporting this issue. + +Added a tutorial for moving tables between Piccolo apps - thanks to +@sarvesh4396 for this. + +------------------------------------------------------------------------------- + +1.26.0 +------ + +Improved auto migrations - ``ON DELETE`` and ``ON UPDATE`` can be modified +on ``ForeignKey`` columns. Thanks to @sinisaos for this. + +------------------------------------------------------------------------------- + +1.25.0 +------ + +Improvements to Piccolo app creation. When running the following: + +.. code-block:: bash + + piccolo app new my_app + +It now validates that the app name (``my_app`` in this case) is valid as a +Python package. + +Also, there is now a ``--register`` flag, which automatically adds the new app +to the ``APP_REGISTRY`` in ``piccolo_conf.py``. + +.. code-block:: bash + + piccolo app new my_app --register + +Other changes: + +* ``table_finder`` can now use relative modules. +* Updated the Esmerald ASGI template (thanks to @sinisaos for this). +* When using the ``remove`` method to delete a row from the database + (``await some_band.remove()``), ``some_band._exists_in_db`` is now set to + ``False``.
Thanks to @sinisaos for this fix. + +------------------------------------------------------------------------------- + +1.24.2 +------ + +Fixed a bug with ``delete`` queries which had joins in the ``where`` clause. +For example: + +.. code-block:: python + + >>> await Band.delete().where(Band.manager.name == 'Guido') + +Thanks to @HakierGrzonzo for reporting the issue, and @sinisaos for the fix. + +------------------------------------------------------------------------------- + +1.24.1 +------ + +Fixed a bug with default values in ``Timestamp`` and ``Timestamptz`` columns. +Thanks to @splch for this. + +------------------------------------------------------------------------------- + +1.24.0 +------ + +* Fixed a bug with ``get_or_create`` when a table has a column with both + ``null=False`` and ``default=None`` - thanks to @bymoye for reporting this + issue. +* If a ``PostgresEngine`` uses the ``dsn`` argument for ``asyncpg``, this is + now used by ``piccolo sql_shell run``. Thanks to @abhishek-compro for + suggesting this. +* Fixed the type annotation for the ``length`` argument of ``Varchar`` - it + is allowed to be ``None``. Thanks to @Compro-Prasad for this. + +------------------------------------------------------------------------------- + +1.23.0 +------ + +* Added Quart, Sanic, and Falcon as supported ASGI frameworks (thanks to + @sinisaos for this). +* Fixed a bug with very large integers in SQLite. +* Fixed type annotation for ``Timestamptz`` default values (thanks to @Skelmis + for this). + +------------------------------------------------------------------------------- + +1.22.0 +------ + +Python 3.13 is now officially supported. + +``JSON`` / ``JSONB`` querying has been significantly improved. For example, if +we have this table: + +.. code-block:: python + + class RecordingStudio(Table): + facilities = JSONB() + +And the ``facilities`` column contains the following JSON data: + +.. code-block:: python + + { + "technicians": [ + {"name": "Alice Jones"}, + {"name": "Bob Williams"}, + ] + } + +We can get the first technician name as follows: + +.. code-block:: python + + >>> await RecordingStudio.select( + ... RecordingStudio.facilities["technicians"][0]["name"].as_alias("name") + ... ).output(load_json=True) + [{'name': 'Alice Jones'}, ...] + +``TableStorage`` (used for dynamically creating Piccolo ``Table`` classes from +an existing database) was improved, to support a Dockerised version of Piccolo +Admin, which is coming soon. + +------------------------------------------------------------------------------- + +1.21.0 +------ + +Postgres 17 is now officially supported. + +Fixed a bug with joins, when a ``ForeignKey`` column had ``db_column_name`` +specified. Thanks to @jessemcl-flwls for reporting this issue. + +------------------------------------------------------------------------------- + +1.20.0 +------ + +``get_related`` now works multiple layers deep: + +.. code-block:: python + + concert = await Concert.objects().first() + manager = await concert.get_related(Concert.band_1._.manager) + +------------------------------------------------------------------------------- + +1.19.1 +------ + +Fixed a bug with the ``get_m2m`` method, which would raise a ``ValueError`` +when no objects were found. It now handles this gracefully and returns an empty +list instead. Thanks to @nVitius for this fix. + +Improved the ASGI templates (including a fix for the latest Litestar version). +Thanks to @sinisaos for this. 
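 +To illustrate the ``get_m2m`` fix, here is a minimal sketch - it assumes a +``Band`` table with an ``M2M`` column called ``genres``, as used in the M2M +documentation: + +.. code-block:: python + + >>> band = await Band.objects().get(Band.name == 'Pythonistas') + >>> await band.get_m2m(Band.genres) # no related rows - no longer raises + []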
+ +------------------------------------------------------------------------------- + +1.19.0 +------ + +Added support for row locking (i.e. ``SELECT ... FOR UPDATE``). + +For example, if we have this table: + +.. code-block:: python + + class Concert(Table): + name = Varchar() + tickets_available = Integer() + +And we want to make sure that ``tickets_available`` never goes below 0, we can +do the following: + +.. code-block:: python + + async def book_tickets(ticket_count: int): + async with Concert._meta.db.transaction(): + concert = await Concert.objects().where( + Concert.name == "Awesome Concert" + ).first().lock_rows() + + if concert.tickets_available >= ticket_count: + await concert.update_self({ + Concert.tickets_available: Concert.tickets_available - ticket_count + }) + else: + raise ValueError("Not enough tickets are available!") + +This means that when multiple transactions are running at the same time, it +isn't possible to book more tickets than are available. + +Thanks to @dkopitsa for adding this feature. + +------------------------------------------------------------------------------- + +1.18.0 +------ + +``update_self`` +~~~~~~~~~~~~~~~ + +Added the ``update_self`` method, which is an alternative to the ``save`` +method. Here's an example where it's useful: + +.. code-block:: python + + # If we have a band object: + >>> band = await Band.objects().get(name="Pythonistas") + >>> band.popularity + 1000 + + # We can increment the popularity, based on the current value in the + # database: + >>> await band.update_self({ + ... Band.popularity: Band.popularity + 1 + ... }) + + # The new value is set on the object: + >>> band.popularity + 1001 + + # It's safer than using the `save` method, because the popularity value on + # the object might be out of date with what's in the database: + band.popularity += 1 + await band.save() + +Thanks to @trondhindenes for suggesting this feature. + +Batch raw queries +~~~~~~~~~~~~~~~~~ + +The ``batch`` method can now be used with ``raw`` queries. For example: + +.. code-block:: python + + async with await MyTable.raw("SELECT * FROM my_table").batch() as batch: + async for _batch in batch: + print(_batch) + +This is useful when you expect a raw query to return a lot of data. + +Thanks to @devsarvesh92 for suggesting this feature. + +------------------------------------------------------------------------------- + +1.17.1 +------ + +Fixed a bug with migrations, where altering a column type from ``Integer`` to +``Float`` could fail. Thanks to @kurtportelli for reporting this issue. + +------------------------------------------------------------------------------- + +1.17.0 +------ + +Each migration is automatically wrapped in a transaction - this can now be +disabled using the ``wrap_in_transaction`` argument: + +.. code-block:: python + + manager = MigrationManager( + wrap_in_transaction=False, + ... + ) + +This is useful when writing a manual migration, and you want to manage all of +the transaction logic yourself (or want multiple transactions). + +``granian`` is now a supported server in the ASGI templates. Thanks to +@sinisaos for this. + +------------------------------------------------------------------------------- + +1.16.0 +------ + +Added custom async ``TestCase`` subclasses, to help with testing. + +For example ``AsyncTransactionTest``, which wraps each test in a transaction +automatically: + +.. 
code-block:: python + + class TestBandEndpoint(AsyncTransactionTest): + + async def test_band_response(self): + """ + Make sure the endpoint returns a 200. + """ + # This data automatically gets removed from the database when the + # test finishes: + band = Band({Band.name: "Pythonistas"}) + await band.save() + + # Using an API testing client, like httpx: + response = await client.get(f"/bands/{band.id}/") + self.assertEqual(response.status_code, 200) + +And ``AsyncTableTest``, which automatically creates and drops tables: + +.. code-block:: python + + class TestBand(AsyncTableTest): + + # These tables automatically get created and dropped: + tables = [Band] + + async def test_band(self): + ... + +------------------------------------------------------------------------------- + +1.15.0 +------ + +Improved ``refresh`` - it now works with prefetched objects. For example: + +.. code-block:: python + + >>> band = await Band.objects(Band.manager).first() + >>> band.manager.name + "Guido" + + # If the manager has changed in the database, when we refresh the band, the + # manager object will also be updated: + >>> await band.refresh() + >>> band.manager.name + "New name" + +Also, improved the error messages when creating a ``BaseUser`` - thanks to +@haaavk for this. + +------------------------------------------------------------------------------- + +1.14.0 +------ + +Laying the foundations for alternative Postgres database drivers (e.g. +``psqlpy``). Thanks to @insani7y and @chandr-andr for their help with this. + +.. warning:: + The SQL generated by Piccolo changed slightly in this release. Aliases used + to be like ``"manager$name"`` but now they are like ``"manager.name"`` + (note ``$`` changed to ``.``). If you are using ``SelectRaw`` in your queries + to refer to these columns, then they will need updating. Please let us know + if you encounter any other issues. + +------------------------------------------------------------------------------- + +1.13.1 +------ + +In Piccolo ``1.6.0`` we moved some aggregate functions to a new file. We now +re-export them from their original location to keep backwards compatibility. +Thanks to @sarvesh-deserve for reporting this issue. + +------------------------------------------------------------------------------- + +1.13.0 +------ + +Improved ``LazyTableReference``, to help prevent circular import errors. + +------------------------------------------------------------------------------- + +1.12.0 +------ + +* Added documentation for one to one fields. +* Upgraded ASGI templates (thanks to @sinisaos for this). +* Migrations can now be hardcoded as fake. +* Refactored tests to reduce boilerplate code. +* Updated documentation dependencies. + +------------------------------------------------------------------------------- + +1.11.0 +------ + +Added datetime functions, for example ``Year``: + +.. code-block:: python + + >>> from piccolo.query.functions import Year + >>> await Concert.select(Year(Concert.starts, alias="starts_year")) + [{'starts_year': 2024}] + +Added the ``Concat`` function, for concatenating strings: + +.. code-block:: python + + >>> from piccolo.query.functions import Concat + >>> await Band.select( + ... Concat( + ... Band.name, + ... '-', + ... Band.manager._.name, + ... alias="name_and_manager" + ... ) + ... ) + [{"name_and_manager": "Pythonistas-Guido"}] + +------------------------------------------------------------------------------- + +1.10.0 +------ + +Added ``not_any`` method for ``Array`` columns.
This will return rows where an +array doesn't contain the given value. For example: + +.. code-block:: python + + class MyTable(Table): + array_column = Array(Integer()) + + >>> await MyTable.select( + ... MyTable.array_column + ... ).where( + ... MyTable.array_column.not_any(1) + ... ) + [{"array_column": [4, 5, 6]}] + +Also fixed a bunch of Pylance linter warnings across the codebase. + +------------------------------------------------------------------------------- + +1.9.0 +----- + +Added some math functions, for example ``Abs``, ``Ceil``, ``Floor`` and +``Round``. + +.. code-block:: python + + >>> from piccolo.query.functions import Round + >>> await Ticket.select(Round(Ticket.price, alias="price")) + [{'price': 50.0}] + +Added more operators to ``QueryString`` (multiply, divide, modulus, power), so +we can do things like: + +.. code-block:: python + + >>> await Ticket.select(Round(Ticket.price) * 2) + [{'price': 100.0}] + +Fixed some edge cases around defaults for ``Array`` columns. + +.. code-block:: python + + def get_default(): + # This used to fail: + return [datetime.time(hour=8, minute=0)] + + class MyTable(Table): + times = Array(Time(), default=get_default) + +Fixed some deprecation warnings, and improved CockroachDB array tests. + +------------------------------------------------------------------------------- + +1.8.0 +----- + +Added the ``Cast`` function, for performing type conversion. + +Here's an example, where we convert a ``timestamp`` to ``time``: + +.. code-block:: python + + >>> from piccolo.columns import Time + >>> from piccolo.query.functions import Cast + + >>> await Concert.select(Cast(Concert.starts, Time())) + [{'starts': datetime.time(19, 0)}] + +A new section was also added to the docs describing functions in more detail. + +------------------------------------------------------------------------------- + +1.7.0 +----- + +Arrays of ``Date`` / ``Time`` / ``Timestamp`` / ``Timestamptz`` now work in +SQLite. + +For example: + +.. code-block:: python + + class MyTable(Table): + times = Array(Time()) + dates = Array(Date()) + timestamps = Array(Timestamp()) + timestamps_tz = Array(Timestamptz()) + +------------------------------------------------------------------------------- + +1.6.0 +----- + +Added support for a bunch of Postgres functions, like ``Upper``, ``Lower``, +``Length``, and ``Ltrim``. They can be used in ``select`` queries: + +.. code-block:: python + + from piccolo.query.functions.string import Upper + >>> await Band.select(Upper(Band.name, alias="name")) + [{"name": "PYTHONISTAS"}] + +And also in ``where`` clauses: + +.. code-block:: python + + >>> await Band.select().where(Upper(Band.manager.name) == 'GUIDO') + [{"name": "Pythonistas"}] + +------------------------------------------------------------------------------- + +1.5.2 +----- + +Added an ``Album`` table to the playground, along with some other +improvements. + +Fixed a bug with the ``output(load_json=True)`` clause, when used on joined +tables. + +------------------------------------------------------------------------------- + +1.5.1 +----- + +Fixed a bug with the CLI when reversing migrations (thanks to @metakot for +reporting this). + +Updated the ASGI templates (thanks to @tarsil for adding Lilya). + +------------------------------------------------------------------------------- + +1.5.0 +----- + +Lots of internal improvements, mostly to support new functionality in Piccolo +Admin. 
+ +------------------------------------------------------------------------------- + +1.4.2 +----- + +Improved how ``ModelBuilder`` handles recursive foreign keys. + +------------------------------------------------------------------------------- + +1.4.1 +----- + +Fixed an edge case with auto migrations. + +If starting from a table like this, with a custom primary key column: + +.. code-block:: python + + class MyTable(Table): + id = UUID(primary_key=True) + +When a foreign key is added to the table which references itself: + +.. code-block:: python + + class MyTable(Table): + id = UUID(primary_key=True) + fk = ForeignKey("self") + +The auto migrations could fail in some situations. + +------------------------------------------------------------------------------- + +1.4.0 +----- + +Improved how ``create_pydantic_model`` handles ``Array`` columns: + +* Multidimensional arrays (e.g. ``Array(Array(Integer))``) have more accurate + types. +* ``Array(Email())`` now validates that each item in the list is an email + address. +* ``Array(Varchar(length=10))`` now validates that each item is the correct + length (i.e. 10 in this example). + +Other changes +~~~~~~~~~~~~~ + +Some Pylance errors were fixed in the codebase. + +------------------------------------------------------------------------------- + +1.3.2 +----- + +Fixed a bug with nested array columns containing ``BigInt``. For example: + +.. code-block:: python + + class MyTable(Table): + my_column = Array(Array(BigInt)) + +Thanks to @AmazingAkai for reporting this issue. + +------------------------------------------------------------------------------- + +1.3.1 +----- + +Fixed a bug with foreign keys which reference ``BigSerial`` primary keys. +Thanks to @Abdelhadi92 for reporting this issue. + +------------------------------------------------------------------------------- + +1.3.0 +----- + +Added the ``piccolo user list`` command - a quick and convenient way of listing +Piccolo Admin users from the command line. + +``ModelBuilder`` now creates timezone aware ``datetime`` objects for +``Timestamptz`` columns. + +Updated the ASGI templates. + +SQLite auto migrations are now allowed. We used to raise an exception, but +now we output a warning instead. While SQLite auto migrations aren't as feature +rich as Postgres, they work fine for simple use cases. + +------------------------------------------------------------------------------- + +1.2.0 +----- + +There's now an alternative syntax for joins, which works really well with +static type checkers like Mypy and Pylance. + +The traditional syntax (which continues to work as before): + +.. code-block:: python + + # Get the band name, and the manager's name from a related table + await Band.select(Band.name, Band.manager.name) + +The alternative syntax is as follows: + +.. code-block:: python + + await Band.select(Band.name, Band.manager._.name) + +Note how we use ``._.`` instead of ``.`` after a ``ForeignKey``. + +This offers a considerably better static typing experience. In the above +example, type checkers know that ``Band.manager._.name`` refers to the ``name`` +column on the ``Manager`` table. This means typos can be detected, and code +navigation is easier. + +Other changes +~~~~~~~~~~~~~ + +* Improve static typing for ``get_related``. +* Added support for the ``esmerald`` ASGI framework. + +------------------------------------------------------------------------------- + +1.1.1 +----- + +Piccolo allows the user to specify savepoint names which are used in +transactions. 
For example: + +.. code-block:: python + + async with DB.transaction() as transaction: + await Band.insert(Band(name='Pythonistas')) + + # Passing in a savepoint name is optional: + savepoint_1 = await transaction.savepoint('savepoint_1') + + await Band.insert(Band(name='Terrible band')) + + # Oops, I made a mistake! + await savepoint_1.rollback_to() + +Postgres doesn't allow us to parameterise savepoint names, which means there's +a small chance of SQL injection, if for some reason the savepoint names were +generated from end-user input. Even though the likelihood is very low, it's +best to be safe. We now validate the savepoint name, to make sure it can only +contain certain safe characters. Thanks to @Skelmis for making this change. + +------------------------------------------------------------------------------- + +1.1.0 +----- + +Added support for Python 3.12. + +Modified ``create_pydantic_model``, so additional information is returned in +the JSON schema to distinguish between ``Timestamp`` and ``Timestamptz`` +columns. This will be used for future Piccolo Admin enhancements. + +------------------------------------------------------------------------------- + +1.0.0 +----- + +Piccolo v1 is now available! + +We migrated to Pydantic v2, and also migrated Piccolo Admin to Vue 3, which +puts the project in a good place moving forward. + +We don't anticipate any major issues for people who are upgrading. If you +encounter any bugs let us know. + +Make sure you have v1 of Piccolo, Piccolo API, and Piccolo Admin. + +------------------------------------------------------------------------------- + +1.0a3 +----- + +Namespaced all custom values we added to Pydantic's JSON schema for easier +maintenance. + +------------------------------------------------------------------------------- + +1.0a2 +----- + +All of the changes from 0.120.0 merged into the v1 branch. + +------------------------------------------------------------------------------- + +0.121.0 +------- + +Modified the ``BaseUser.login`` logic so all code paths take the same time. +Thanks to @Skelmis for this. + +------------------------------------------------------------------------------- + +0.120.0 +------- + +Improved how ``ModelBuilder`` generates JSON data. + +The number of password hash iterations used in ``BaseUser`` has been increased +to keep pace with the latest guidance from OWASP - thanks to @Skelmis for this. + +Fixed a bug with auto migrations when the table is in a schema. + +------------------------------------------------------------------------------- + +1.0a1 +----- + +Initial alpha release of Piccolo v1, with Pydantic v2 support. + +------------------------------------------------------------------------------- + +0.119.0 +------- + +``ModelBuilder`` now works with ``LazyTableReference`` (which is used when we +have circular references caused by a ``ForeignKey``). + +With this table: + +.. code-block:: python + + class Band(Table): + manager = ForeignKey( + LazyTableReference( + 'Manager', + module_path='some.other.folder.tables' + ) + ) + +We can now create a dynamic test fixture: + +.. code-block:: python + + my_model = await ModelBuilder.build(Band) + +------------------------------------------------------------------------------- + +0.118.0 +------- + +If you have lots of Piccolo apps, you can now create auto migrations for them +all in one go: + +.. code-block:: bash + + piccolo migrations new all --auto + +Thanks to @hoosnick for suggesting this new feature. 
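 +Once generated, they can all be applied in one go as well, using the existing +``forwards`` command (shown here for convenience - see the migration docs for +details): + +.. code-block:: bash + + piccolo migrations forwards all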
+ +The documentation for running migrations has also been improved, as well as +improvements to the sorting of migrations based on their dependencies. + +Support for Python 3.7 was dropped in this release as it's now end of life. + +------------------------------------------------------------------------------- + +0.117.0 +------- + +Version pinning Pydantic to v1, as v2 has breaking changes. + +We will add support for Pydantic v2 in a future release. + +Thanks to @sinisaos for helping with this. + +------------------------------------------------------------------------------- + +0.116.0 +------- + +Fixture formatting +~~~~~~~~~~~~~~~~~~ + +When creating a fixture: + +.. code-block:: bash + + piccolo fixtures dump + +The JSON output is now nicely formatted, which is useful because we can pipe +it straight to a file, and commit it to Git without having to manually run a +formatter on it. + +.. code-block:: bash + + piccolo fixtures dump > my_fixture.json + +Thanks to @sinisaos for this. + +Protected table names +~~~~~~~~~~~~~~~~~~~~~ + +We used to raise a ``ValueError`` if a table was called ``user``. + +.. code-block:: python + + class User(Table): # ValueError! + ... + +It's because ``user`` is already used by Postgres (e.g. try ``SELECT user`` or +``SELECT * FROM user``). + +We now emit a warning instead for these reasons: + +* Piccolo wraps table names in quotes to avoid clashes with reserved keywords. +* Sometimes you're stuck with a table name from a pre-existing schema, and + can't easily rename it. + +Re-export ``WhereRaw`` +~~~~~~~~~~~~~~~~~~~~~~ + +If you want to write raw SQL in your where queries you use ``WhereRaw``: + +.. code-block:: python + + >>> Band.select().where(WhereRaw('TRIM(name) = {}', 'Pythonistas')) + +You can now import it from ``piccolo.query`` to be consistent with +``SelectRaw`` and ``OrderByRaw``. + +.. code-block:: python + + from piccolo.query import WhereRaw + +------------------------------------------------------------------------------- + +0.115.0 +------- + +Fixture upserting +~~~~~~~~~~~~~~~~~ + +Fixtures can now be upserted. For example: + +.. code-block:: bash + + piccolo fixtures load my_fixture.json --on_conflict='DO UPDATE' + +The options are: + +* ``DO NOTHING``, meaning any rows with a matching primary key will be left + alone. +* ``DO UPDATE``, meaning any rows with a matching primary key will be updated. + +This is really useful, as you can now edit fixtures and load them multiple +times without getting foreign key constraint errors. + +Schema fixes +~~~~~~~~~~~~ + +We recently added support for schemas, for example: + +.. code-block:: python + + class Band(Table, schema='music'): + ... + +This release contains: + +* A fix for migrations when changing a table's schema back to 'public' (thanks to + @sinisaos for discovering this). +* A fix for ``M2M`` queries, when the tables are in a schema other than + 'public' (thanks to @quinnalfaro for reporting this). + +Added ``distinct`` method to ``count`` queries +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We recently added support for ``COUNT DISTINCT`` queries. The syntax is: + +.. code-block:: python + + await Concert.count(distinct=[Concert.start_date]) + +The following alternative syntax now also works (just to be consistent with +other queries like ``select``): + +.. 
code-block:: python + + await Concert.count().distinct([Concert.start_date]) + +------------------------------------------------------------------------------- + +0.114.0 +------- + +``count`` queries can now return the number of distinct rows. For example, if +we have this table: + +.. code-block:: python + + class Concert(Table): + band = Varchar() + start_date = Date() + +With this data: + +.. table:: + :widths: auto + + =========== ========== + band start_date + =========== ========== + Pythonistas 2023-01-01 + Pythonistas 2023-02-03 + Rustaceans 2023-01-01 + =========== ========== + +We can easily get the number of unique concert dates: + +.. code-block:: python + + >>> await Concert.count(distinct=[Concert.start_date]) + 2 + +We could have just done this instead: + +.. code-block:: python + + len(await Concert.select(Concert.start_date).distinct()) + +But it's far less efficient when you have lots of rows, because all of the +distinct rows need to be returned from the database. + +Also, the docs for the ``count`` query, aggregate functions, and +``group_by`` clause were significantly improved. + +Many thanks to @lqmanh and @sinisaos for their help with this. + +------------------------------------------------------------------------------- + +0.113.0 +------- + +If Piccolo detects a renamed table in an auto migration, it asks the user for +confirmation. When lots of tables have been renamed, Piccolo is now more +intelligent about when to ask for confirmation. Thanks to @sumitsharansatsangi +for suggesting this change, and @sinisaos for reviewing. + +Also, fixed the type annotations for ``MigrationManager.add_table``. + +------------------------------------------------------------------------------- + +0.112.1 +------- + +Fixed a bug with serialising table classes in migrations. + +------------------------------------------------------------------------------- + +0.112.0 +------- + +Added support for schemas in Postgres and CockroachDB. + +For example: + +.. code-block:: python + + class Band(Table, schema="music"): + ... + +When creating the table, the schema will be created automatically if it doesn't +already exist. + +.. code-block:: python + + await Band.create_table() + +It also works with migrations. If we change the ``schema`` value for the table, +Piccolo will detect this, and create a migration for moving it to the new schema. + +.. code-block:: python + + class Band(Table, schema="music_2"): + ... + + # Piccolo will detect that the table needs to be moved to a new schema. + >>> piccolo migrations new my_app --auto + +------------------------------------------------------------------------------- + +0.111.1 +------- + +Fixing a bug with ``ModelBuilder`` and ``Decimal`` / ``Numeric`` columns. + +------------------------------------------------------------------------------- + +0.111.0 +------- + +Added the ``on_conflict`` clause for ``insert`` queries. This enables **upserts**. + +For example, here we insert some bands, and if they already exist then do +nothing: + +.. code-block:: python + + await Band.insert( + Band(name='Pythonistas'), + Band(name='Rustaceans'), + Band(name='C-Sharps'), + ).on_conflict(action='DO NOTHING') + +Here we insert some albums, and if they already exist then we update the price: + +.. 
code-block:: python
+
+    await Album.insert(
+        Album(title='OK Computer', price=10.49),
+        Album(title='Kid A', price=9.99),
+        Album(title='The Bends', price=9.49),
+    ).on_conflict(
+        action='DO UPDATE',
+        target=Album.title,
+        values=[Album.price]
+    )
+
+Thanks to @sinisaos for helping with this.
+
+-------------------------------------------------------------------------------
+
+0.110.0
+-------
+
+ASGI frameworks
+~~~~~~~~~~~~~~~
+
+The ASGI frameworks in ``piccolo asgi new`` have been updated. ``starlite`` has
+been renamed to ``litestar``. Thanks to @sinisaos for this.
+
+ModelBuilder
+~~~~~~~~~~~~
+
+Generic types are now used in ``ModelBuilder``.
+
+.. code-block:: python
+
+    # mypy knows this is a `Band` instance:
+    band = await ModelBuilder.build(Band)
+
+``DISTINCT ON``
+~~~~~~~~~~~~~~~
+
+Added support for ``DISTINCT ON`` queries. For example, here we fetch the most
+recent album for each band:
+
+.. code-block:: python
+
+    >>> await Album.select().distinct(
+    ...     on=[Album.band]
+    ... ).order_by(
+    ...     Album.band
+    ... ).order_by(
+    ...     Album.release_date,
+    ...     ascending=False
+    ... )
+
+Thanks to @sinisaos and @williamflaherty for their help with this.
+
+-------------------------------------------------------------------------------
+
+0.109.0
+-------
+
+Joins are now possible without foreign keys using ``join_on``.
+
+For example:
+
+.. code-block:: python
+
+    class Manager(Table):
+        name = Varchar(unique=True)
+        email = Varchar()
+
+    class Band(Table):
+        name = Varchar()
+        manager_name = Varchar()
+
+    >>> await Band.select(
+    ...     Band.name,
+    ...     Band.manager_name.join_on(Manager.name).email
+    ... )
+
+-------------------------------------------------------------------------------
+
+0.108.0
+-------
+
+Added support for savepoints within transactions.
+
+.. code-block:: python
+
+    async with DB.transaction() as transaction:
+        await Manager.objects().create(name="Great manager")
+        savepoint = await transaction.savepoint()
+        await Manager.objects().create(name="Great manager")
+        await savepoint.rollback_to()
+        # Only the first manager will be inserted.
+
+The behaviour of nested context managers has also been changed slightly.
+
+.. code-block:: python
+
+    async with DB.transaction():
+        async with DB.transaction():
+            # This used to raise an exception
+
+We no longer raise an exception if there are nested transaction context
+managers; instead, the inner ones do nothing.
+
+If you want the existing behaviour:
+
+.. code-block:: python
+
+    async with DB.transaction():
+        async with DB.transaction(allow_nested=False):
+            # TransactionError!
+
+-------------------------------------------------------------------------------
+
+0.107.0
+-------
+
+Added the ``log_responses`` option to the database engines. This makes the
+engine print out the raw response from the database for each query, which
+is useful during debugging.
+
+.. code-block:: python
+
+    # piccolo_conf.py
+
+    DB = PostgresEngine(
+        config={'database': 'my_database'},
+        log_queries=True,
+        log_responses=True
+    )
+
+We also updated the Starlite ASGI template - it now uses the new import paths
+(thanks to @sinisaos for this).
+
+-------------------------------------------------------------------------------
+
+0.106.0
+-------
+
+Joins now work within ``update`` queries. For example:
+
+.. 
code-block:: python + + await Band.update({ + Band.name: 'Amazing Band' + }).where( + Band.manager.name == 'Guido' + ) + +Other changes: + +* Improved the template used by ``piccolo app new`` when creating a new + Piccolo app (it now uses ``table_finder``). + +------------------------------------------------------------------------------- + +0.105.0 +------- + +Improved the performance of select queries with complex joins. Many thanks to +@powellnorma and @sinisaos for their help with this. + +------------------------------------------------------------------------------- + +0.104.0 +------- + +Major improvements to Piccolo's typing / auto completion support. + +For example: + +.. code-block:: python + + >>> bands = await Band.objects() # List[Band] + + >>> band = await Band.objects().first() # Optional[Band] + + >>> bands = await Band.select().output(as_json=True) # str + +------------------------------------------------------------------------------- + +0.103.0 +------- + +``SelectRaw`` +~~~~~~~~~~~~~ + +This allows you to access features in the database which aren't exposed +directly by Piccolo. For example, Postgres functions: + +.. code-block:: python + + from piccolo.query import SelectRaw + + >>> await Band.select( + ... Band.name, + ... SelectRaw("log(popularity) AS log_popularity") + ... ) + [{'name': 'Pythonistas', 'log_popularity': 3.0}] + +Large fixtures +~~~~~~~~~~~~~~ + +Piccolo can now load large fixtures using ``piccolo fixtures load``. The +rows are inserted in batches, so the database adapter doesn't raise any errors. + +------------------------------------------------------------------------------- + +0.102.0 +------- + +Migration file names +~~~~~~~~~~~~~~~~~~~~ + +The naming convention for migrations has changed slightly. It used to be just +a timestamp - for example: + +.. code-block:: text + + 2021-09-06T13-58-23-024723.py + +By convention Python files should start with a letter, and only contain +``a-z``, ``0-9`` and ``_``, so the new format is: + +.. code-block:: text + + my_app_2021_09_06T13_58_23_024723.py + +.. note:: You can name a migration file anything you want (it's the ``ID`` + value inside it which is important), so this change doesn't break anything. + +Enhanced Pydantic configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We now expose all of Pydantic's configuration options to +``create_pydantic_model``: + +.. code-block:: python + + class MyPydanticConfig(pydantic.BaseConfig): + extra = 'forbid' + + model = create_pydantic_model( + table=MyTable, + pydantic_config_class=MyPydanticConfig + ) + +Thanks to @waldner for this. + +Other changes +~~~~~~~~~~~~~ + +* Fixed a bug with ``get_or_create`` and null columns (thanks to @powellnorma + for reporting this issue). +* Updated the Starlite ASGI template, so it uses the latest syntax for mounting + Piccolo Admin (thanks to @sinisaos for this, and the Starlite team). + +------------------------------------------------------------------------------- + +0.101.0 +------- + +``piccolo fixtures load`` is now more intelligent about how it loads data, to +avoid foreign key constraint errors. + +------------------------------------------------------------------------------- + +0.100.0 +------- + +``Array`` columns now support choices. + +.. code-block:: python + + class Ticket(Table): + class Extras(str, enum.Enum): + drink = "drink" + snack = "snack" + program = "program" + + extras = Array(Varchar(), choices=Extras) + +We can then use the ``Enum`` in our queries: + +.. 
code-block:: python + + >>> await Ticket.insert( + ... Ticket(extras=[Extras.drink, Extras.snack]), + ... Ticket(extras=[Extras.program]), + ... ) + +This will also be supported in Piccolo Admin in the next release. + +------------------------------------------------------------------------------- + +0.99.0 +------ + +You can now use the ``returning`` clause with ``delete`` queries. + +For example: + +.. code-block:: python + + >>> await Band.delete().where(Band.popularity < 100).returning(Band.name) + [{'name': 'Terrible Band'}, {'name': 'Awful Band'}] + +This also means you can count the number of deleted rows: + +.. code-block:: python + + >>> len(await Band.delete().where(Band.popularity < 100).returning(Band.id)) + 2 + +Thanks to @waldner for adding this feature. + +------------------------------------------------------------------------------- + +0.98.0 +------ + +SQLite ``TransactionType`` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can now specify the transaction type for SQLite. + +This is useful when using SQLite in production, as it's possible to get +``database locked`` errors if you're running lots of transactions concurrently, +and don't use the correct transaction type. + +In this example we use an ``IMMEDIATE`` transaction: + +.. code-block:: python + + from piccolo.engine.sqlite import TransactionType + + async with Band._meta.db.transaction( + transaction_type=TransactionType.immediate + ): + band = await Band.objects().get_or_create(Band.name == 'Pythonistas') + ... + +We've added a `new tutorial `_ +which explains this in more detail, as well as other tips for using asyncio and +SQLite together effectively. + +Thanks to @powellnorma and @sinisaos for their help with this. + +Other changes +~~~~~~~~~~~~~ + +* Fixed a bug with camelCase column names (we recommend using snake_case, but + sometimes it's unavoidable when using Piccolo with an existing schema). + Thanks to @sinisaos for this. +* Fixed a typo in the docs with ``raw`` queries - thanks to @StitiFatah for + this. + +------------------------------------------------------------------------------- + +0.97.0 +------ + +Some big improvements to ``order_by`` clauses. + +It's now possible to combine ascending and descending: + +.. code-block:: python + + await Band.select( + Band.name, + Band.popularity + ).order_by( + Band.name + ).order_by( + Band.popularity, + ascending=False + ) + +You can also order by anything you want using ``OrderByRaw``: + +.. code-block:: python + + from piccolo.query import OrderByRaw + + await Band.select( + Band.name + ).order_by( + OrderByRaw('random()') + ) + +------------------------------------------------------------------------------- + +0.96.0 +------ + +Added the ``auto_update`` argument to ``Column``. Its main use case is columns +like ``modified_on`` where we want the value to be updated automatically each +time the row is saved. + +.. code-block:: python + + class Band(Table): + name = Varchar() + popularity = Integer() + modified_on = Timestamp( + null=True, + default=None, + auto_update=datetime.datetime.now + ) + + # The `modified_on` column will automatically be updated to the current + # timestamp: + >>> await Band.update({ + ... Band.popularity: Band.popularity + 100 + ... }).where( + ... Band.name == 'Pythonistas' + ... ) + +It works with ``MyTable.update`` and also when using the ``save`` method on +an existing row. + +------------------------------------------------------------------------------- + +0.95.0 +------ + +Made improvements to the Piccolo playground. 
+
+* Syntax highlighting is now enabled.
+* The example queries are now async (IPython supports top level await, so
+  this works fine).
+* You can optionally use your own IPython configuration
+  ``piccolo playground run --ipython_profile`` (for example, if you want a
+  specific colour scheme, rather than the one we use by default).
+
+Thanks to @haffi96 for this. See `PR 656 `_.
+
+-------------------------------------------------------------------------------
+
+0.94.0
+------
+
+Fixed a bug with ``MyTable.objects().create()`` and columns which are not
+nullable. Thanks to @metakot for reporting this issue.
+
+We used to use ``logging.getLogger(__file__)``, but as @Drapersniper pointed
+out, the Python docs recommend ``logging.getLogger(__name__)``, so it has been
+changed.
+
+-------------------------------------------------------------------------------
+
+0.93.0
+------
+
+* Fixed a bug with nullable ``JSON`` / ``JSONB`` columns and
+  ``create_pydantic_model`` - thanks to @eneacosta for this fix.
+* Made the ``Time`` column type importable from ``piccolo.columns``.
+* Python 3.11 is now supported.
+* Postgres 9.6 is no longer officially supported, as it's end of life, but
+  Piccolo should continue to work with it just fine for now.
+* Improved docs for transactions, added docs for the ``as_of`` clause in
+  CockroachDB (thanks to @gnat for this), and added docs for
+  ``add_raw_backwards``.
+
+-------------------------------------------------------------------------------
+
+0.92.0
+------
+
+Added initial support for CockroachDB (thanks to @gnat for this massive
+contribution).
+
+Fixed Pylance warnings (thanks to @MiguelGuthridge for this).
+
+-------------------------------------------------------------------------------
+
+0.91.0
+------
+
+Added support for Starlite. If you use ``piccolo asgi new`` you'll see it as
+an option for a router.
+
+Thanks to @sinisaos for adding this, and @peterschutt for helping debug ASGI
+mounting.
+
+-------------------------------------------------------------------------------
+
+0.90.0
+------
+
+Fixed an edge case where a migration could fail if:
+
+* 5 or more tables were being created at once.
+* They all contained foreign keys to each other, as shown below.
+
+.. code-block:: python
+
+    class TableA(Table):
+        pass
+
+    class TableB(Table):
+        fk = ForeignKey(TableA)
+
+    class TableC(Table):
+        fk = ForeignKey(TableB)
+
+    class TableD(Table):
+        fk = ForeignKey(TableC)
+
+    class TableE(Table):
+        fk = ForeignKey(TableD)
+
+
+Thanks to @sumitsharansatsangi for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.89.0
+------
+
+Made it easier to access the ``Email`` columns on a table.
+
+.. code-block:: python
+
+    >>> MyTable._meta.email_columns
+    [MyTable.email_column_1, MyTable.email_column_2]
+
+This was added for Piccolo Admin.
+
+-------------------------------------------------------------------------------
+
+0.88.0
+------
+
+Fixed a bug with migrations - when using ``db_column_name`` it wasn't being
+used in some alter statements. Thanks to @theelderbeever for reporting this
+issue.
+
+.. code-block:: python
+
+    class Concert(Table):
+        # We use `db_column_name` when the column name is problematic - e.g. if
+        # it clashes with a Python keyword.
+        in_ = Varchar(db_column_name='in')
+
+-------------------------------------------------------------------------------
+
+0.87.0
+------
+
+When using ``get_or_create`` with ``prefetch`` the behaviour was inconsistent -
+it worked as expected when the row already existed, but prefetch wasn't working
+if the row was being created. This now works as expected:
+
+.. code-block:: python
+
+    >>> band = await Band.objects(Band.manager).get_or_create(
+    ...     (Band.name == "New Band 2") & (Band.manager == 1)
+    ... )
+
+    >>> band.manager
+    <Manager: 1>
+
+    >>> band.manager.name
+    "Mr Manager"
+
+Thanks to @backwardspy for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.86.0
+------
+
+Added the ``Email`` column type. It's basically identical to ``Varchar``,
+except that when we use ``create_pydantic_model`` we add email validation
+to the generated Pydantic model.
+
+.. code-block:: python
+
+    from piccolo.columns.column_types import Email
+    from piccolo.table import Table
+    from piccolo.utils.pydantic import create_pydantic_model
+
+
+    class MyTable(Table):
+        email = Email()
+
+
+    model = create_pydantic_model(MyTable)
+
+    model(email="not a valid email")
+    # ValidationError!
+
+Thanks to @sinisaos for implementing this feature.
+
+-------------------------------------------------------------------------------
+
+0.85.1
+------
+
+Fixed a bug with migrations - when run backwards, ``raw`` was being called
+instead of ``raw_backwards``. Thanks to @translunar for the fix.
+
+-------------------------------------------------------------------------------
+
+0.85.0
+------
+
+You can now append items to an array in an update query:
+
+.. code-block:: python
+
+    await Ticket.update({
+        Ticket.seat_numbers: Ticket.seat_numbers + [1000]
+    }).where(Ticket.id == 1)
+
+Currently Postgres only. Thanks to @sumitsharansatsangi for suggesting this
+feature.
+
+-------------------------------------------------------------------------------
+
+0.84.0
+------
+
+You can now preview the DDL statements which will be run by Piccolo migrations.
+
+.. code-block:: bash
+
+    piccolo migrations forwards my_app --preview
+
+Thanks to @AliSayyah for adding this feature.
+
+-------------------------------------------------------------------------------
+
+0.83.0
+------
+
+We added support for Postgres read-slaves a few releases ago, but the ``batch``
+clause didn't support it until now. Thanks to @guruvignesh01 for reporting
+this issue, and @sinisaos for help implementing it.
+
+.. code-block:: python
+
+    # Returns 100 rows at a time from read_replica_db
+    async with await Manager.select().batch(
+        batch_size=100,
+        node="read_replica_db",
+    ) as batch:
+        async for _batch in batch:
+            print(_batch)
+
+
+-------------------------------------------------------------------------------
+
+0.82.0
+------
+
+Traditionally, when instantiating a ``Table``, you passed in column values
+using kwargs:
+
+.. code-block:: python
+
+    >>> await Manager(name='Guido').save()
+
+You can now pass in a dictionary instead, which makes it easier for static
+typing analysis tools like Mypy to detect typos.
+
+.. code-block:: python
+
+    >>> await Manager({Manager.name: 'Guido'}).save()
+
+See `PR 565 `_ for more info.
+
+-------------------------------------------------------------------------------
+
+0.81.0
+------
+
+Added the ``returning`` clause to ``insert`` and ``update`` queries.
+
+This can be used to retrieve data from the inserted / modified rows.
+
+Here's an example where we update the unpopular bands and retrieve their
+names, in a single query:
+
+.. code-block:: python
+
+    >>> await Band.update({
+    ...     Band.popularity: Band.popularity + 5
+    ... }).where(
+    ...     Band.popularity < 10
+    ... ).returning(
+    ...     Band.name
+    ... )
+    [{'name': 'Bad sound band'}, {'name': 'Tone deaf band'}]
+
+See `PR 564 `_ and
+`PR 563 `_ for more info.
+
+-------------------------------------------------------------------------------
+
+0.80.2
+------
+
+Fixed a bug with ``Combination.__str__``, which meant that when printing out a
+query for debugging purposes it wasn't showing correctly (courtesy
+@destos).
+
+-------------------------------------------------------------------------------
+
+0.80.1
+------
+
+Fixed a bug with Piccolo Admin and ``_get_related_readable``, which is used
+to show a human-friendly identifier for a row, rather than just the ID.
+
+Thanks to @ethagnawl and @sinisaos for their help with this.
+
+-------------------------------------------------------------------------------
+
+0.80.0
+------
+
+There was a bug when doing joins with a ``JSONB`` column with ``as_alias``.
+
+.. code-block:: python
+
+    class User(Table, tablename="my_user"):
+        name = Varchar(length=120)
+        config = JSONB(default={})
+
+
+    class Subscriber(Table, tablename="subscriber"):
+        name = Varchar(length=120)
+        user = ForeignKey(references=User)
+
+
+    async def main():
+        # This was failing:
+        await Subscriber.select(
+            Subscriber.name,
+            Subscriber.user.config.as_alias("config")
+        )
+
+Thanks to @Anton-Karpenko for reporting this issue.
+
+Even though this is a bug fix, the minor version number has been bumped because
+the fix resulted in some refactoring of Piccolo's internals, so it's a fairly
+big change.
+
+-------------------------------------------------------------------------------
+
+0.79.0
+------
+
+Added a custom ``__repr__`` method to ``Table``'s metaclass. It's needed to
+improve the appearance of our Sphinx docs. See
+`issue 549 `_ for more
+details.
+
+-------------------------------------------------------------------------------
+
+0.78.0
+------
+
+Added the ``callback`` clause to ``select`` and ``objects`` queries (courtesy
+@backwardspy). For example:
+
+.. code-block:: python
+
+    >>> await Band.select().callback(my_callback)
+
+The callback can be a normal function or async function, which is called when
+the query is successful. The callback can be used to modify the query's output.
+
+It allows for some interesting and powerful code. Here's a very simple example
+where we modify the query's output:
+
+.. code-block:: python
+
+    >>> def get_uppercase_names() -> Select:
+    ...     def make_uppercase(response):
+    ...         return [{'name': i['name'].upper()} for i in response]
+    ...
+    ...     return Band.select(Band.name).callback(make_uppercase)
+
+    >>> await get_uppercase_names().where(Band.manager.name == 'Guido')
+    [{'name': 'PYTHONISTAS'}]
+
+Here's another example, where we perform validation on the query's output:
+
+.. code-block:: python
+
+    >>> def get_concerts() -> Select:
+    ...     def check_length(response):
+    ...         if len(response) == 0:
+    ...             raise ValueError('No concerts!')
+    ...         return response
+    ...
+    ...     return Concert.select().callback(check_length)
+
+    >>> await get_concerts().where(Concert.band_1.name == 'Terrible Band')
+    ValueError: No concerts!
+
+At the moment, callbacks are just triggered when a query is successful, but in
+the future other callbacks will be added, to hook into more of Piccolo's
+internals.
+
+-------------------------------------------------------------------------------
+
+0.77.0
+------
+
+Added the ``refresh`` method. If you have an object which has gotten stale, and
+want to refresh it, so it has the latest data from the database, you can now do
+this:
+
+.. code-block:: python
+
+    # If we have an instance:
+    band = await Band.objects().first()
+
+    # And it has gotten stale, we can refresh it:
+    await band.refresh()
+
+Thanks to @trondhindenes for suggesting this feature.
+
+-------------------------------------------------------------------------------
+
+0.76.1
+------
+
+Fixed a bug with ``atomic`` when run async with a connection pool.
+
+For example:
+
+.. code-block:: python
+
+    atomic = Band._meta.db.atomic()
+    atomic.add(query_1, query_2)
+    # This was failing:
+    await atomic.run()
+
+Thanks to @Anton-Karpenko for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.76.0
+------
+
+create_db_tables / drop_db_tables
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Added ``create_db_tables`` and ``create_db_tables_sync`` to replace
+``create_tables``. The problem was ``create_tables`` was sync only, and was
+inconsistent with the rest of Piccolo's API, which is async first.
+``create_tables`` will continue to work for now, but is deprecated, and will be
+removed in version 1.0.
+
+Likewise, ``drop_db_tables`` and ``drop_db_tables_sync`` have replaced
+``drop_tables``.
+
+When calling ``create_tables`` / ``drop_tables`` within other async libraries
+(such as `ward `_) it was sometimes
+unreliable - the best solution was just to make async versions of these
+functions. Thanks to @backwardspy for reporting this issue.
+
+``BaseUser`` password validation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We centralised the password validation logic in ``BaseUser`` into a method
+called ``_validate_password``. This is needed by Piccolo API, but also makes it
+easier for users to override this logic if subclassing ``BaseUser``.
+
+More ``run_sync`` refinements
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``run_sync``, which is the main utility function which Piccolo uses to run
+async code, has been further simplified for compatibility with Python 3.10 and
+above.
+
+-------------------------------------------------------------------------------
+
+0.75.0
+------
+
+Changed how ``piccolo.utils.sync.run_sync`` works, to prevent a warning on
+Python 3.10. Thanks to @Drapersniper for reporting this issue.
+
+Lots of documentation improvements - particularly around testing, and Docker
+deployment.
+
+-------------------------------------------------------------------------------
+
+0.74.4
+------
+
+``piccolo schema generate`` now outputs a warning when it can't detect the
+``ON DELETE`` and ``ON UPDATE`` for a ``ForeignKey``, rather than raising an
+exception. Thanks to @theelderbeever for reporting this issue.
+
+``run_sync`` doesn't use the connection pool by default anymore. It was causing
+issues when an app contained sync and async code. Thanks to @WintonLi for
+reporting this issue.
+
+Added a tutorial to the docs for using Piccolo with an existing project and
+database. Thanks to @virajkanwade for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.74.3
+------
+
+If you had a table containing an array of ``BigInt``, then migrations could
+fail:
+
+.. 
code-block:: python
+
+    from piccolo.table import Table
+    from piccolo.columns.column_types import Array, BigInt
+
+    class MyTable(Table):
+        my_column = Array(base_column=BigInt())
+
+It's because the ``BigInt`` base column needs access to the parent table to
+know if it's targeting Postgres or SQLite. See `PR 501 `_.
+
+Thanks to @cheesycod for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.74.2
+------
+
+If a user created a custom ``Column`` subclass, then migrations would fail.
+For example:
+
+.. code-block:: python
+
+    class CustomColumn(Varchar):
+        def __init__(self, custom_arg: str = '', *args, **kwargs):
+            self.custom_arg = custom_arg
+            super().__init__(*args, **kwargs)
+
+        @property
+        def column_type(self):
+            return 'VARCHAR'
+
+See `PR 497 `_. Thanks to
+@WintonLi for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.74.1
+------
+
+When using ``pip install piccolo[all]`` on Windows it would fail because uvloop
+isn't supported. Thanks to @jack1142 for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.74.0
+------
+
+We've had the ability to bulk modify rows for a while. Here we append ``'!!!'``
+to each band's name:
+
+.. code-block:: python
+
+    >>> await Band.update({Band.name: Band.name + '!!!'}, force=True)
+
+It only worked for some columns - ``Varchar``, ``Text``, ``Integer`` etc.
+
+We now allow ``Date``, ``Timestamp``, ``Timestamptz`` and ``Interval`` columns
+to be bulk modified using a ``timedelta``. Here we modify each concert's start
+date, so it's one day later:
+
+.. code-block:: python
+
+    >>> await Concert.update(
+    ...     {Concert.starts: Concert.starts + timedelta(days=1)},
+    ...     force=True
+    ... )
+
+Thanks to @theelderbeever for suggesting this feature.
+
+-------------------------------------------------------------------------------
+
+0.73.0
+------
+
+You can now specify extra nodes for a database - for example, if you have a
+read replica.
+
+.. code-block:: python
+
+    DB = PostgresEngine(
+        config={'database': 'main_db'},
+        extra_nodes={
+            'read_replica_1': PostgresEngine(
+                config={
+                    'database': 'main_db',
+                    'host': 'read_replica_1.my_db.com'
+                }
+            )
+        }
+    )
+
+You can then run queries on these other nodes:
+
+.. code-block:: python
+
+    >>> await MyTable.select().run(node="read_replica_1")
+
+See `PR 481 `_. Thanks to
+@dashsatish for suggesting this feature.
+
+Also, the ``targ`` library has been updated so it tells users about the
+``--trace`` argument which can be used to get a full traceback when a CLI
+command fails.
+
+-------------------------------------------------------------------------------
+
+0.72.0
+------
+
+Fixed typos with ``drop_constraints``. Courtesy @smythp.
+
+Lots of documentation improvements, such as fixing Sphinx's autodoc for the
+``Array`` column.
+
+``AppConfig`` now accepts a ``pathlib.Path`` instance. For example:
+
+.. code-block:: python
+
+    # piccolo_app.py
+
+    import pathlib
+
+    APP_CONFIG = AppConfig(
+        app_name="blog",
+        migrations_folder_path=pathlib.Path(__file__).parent / "piccolo_migrations"
+    )
+
+Thanks to @theelderbeever for recommending this feature.
+
+-------------------------------------------------------------------------------
+
+0.71.1
+------
+
+Fixed a bug with ``ModelBuilder`` and nullable columns (see `PR 462 `_).
+Thanks to @fiolet069 for reporting this issue.
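+
+As an illustration of the kind of code which was affected, here's a minimal
+sketch (the ``Band`` table, with one or more nullable columns, is assumed):
+
+.. code-block:: python
+
+    from piccolo.testing.model_builder import ModelBuilder
+
+    # Builds a Band instance with randomly generated values, including
+    # for nullable columns:
+    band = await ModelBuilder.build(Band)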
+
+-------------------------------------------------------------------------------
+
+0.71.0
+------
+
+The ``ModelBuilder`` class, which is used to generate mock data in tests, now
+supports ``Array`` columns. Courtesy @backwardspy.
+
+Lots of internal code optimisations and clean up. Courtesy @yezz123.
+
+Added docs for troubleshooting common MyPy errors.
+
+Also thanks to @adriangb for helping us with our dependency issues.
+
+-------------------------------------------------------------------------------
+
+0.70.1
+------
+
+Fixed a bug with auto migrations. If renaming multiple columns at once, it
+could get confused. Thanks to @theelderbeever for reporting this issue, and
+@sinisaos for helping to replicate it. See `PR 457 `_.
+
+-------------------------------------------------------------------------------
+
+0.70.0
+------
+
+We ran a profiler on the Piccolo codebase and identified some optimisations.
+For example, we were calling ``self.querystring`` multiple times in a method,
+rather than assigning it to a local variable.
+
+We also ran a linter which identified when list / set / dict comprehensions
+could be more efficient.
+
+The performance is now slightly improved (especially when fetching large
+numbers of rows from the database).
+
+Example query times on a MacBook, when fetching 1000 rows from a local Postgres
+database (using ``await SomeTable.select()``):
+
+* 8 ms without a connection pool
+* 2 ms with a connection pool
+
+As you can see, having a connection pool is the main thing you can do to
+improve performance.
+
+Thanks to @AliSayyah for all his work on this.
+
+-------------------------------------------------------------------------------
+
+0.69.5
+------
+
+Made improvements to ``piccolo schema generate``, which automatically generates
+Piccolo ``Table`` classes from an existing database.
+
+There were situations where it would fail ungracefully when it couldn't parse
+an index definition. It no longer crashes, and we print out the problematic
+index definitions. See `PR 449 `_.
+Thanks to @gmos for originally reporting this issue.
+
+We also improved the error messages if schema generation fails for some reason
+by letting the user know which table caused the error. Courtesy @AliSayyah.
+
+-------------------------------------------------------------------------------
+
+0.69.4
+------
+
+We used to raise a ``ValueError`` if a column was both ``null=False`` and
+``default=None``. This has now been removed, as there are situations where
+it's valid for columns to be configured that way. Thanks to @gmos for
+suggesting this change.
+
+-------------------------------------------------------------------------------
+
+0.69.3
+------
+
+The ``where`` clause now raises a ``ValueError`` if a boolean value is
+passed in by accident. This was possible in the following situation:
+
+.. code-block:: python
+
+    await Band.select().where(Band.has_drummer is None)
+
+Piccolo can't override the ``is`` operator because Python doesn't allow it,
+so ``Band.has_drummer is None`` will always equal ``False``. Thanks to
+@trondhindenes for reporting this issue.
+
+We've also put a lot of effort into improving documentation throughout the
+project.
+
+-------------------------------------------------------------------------------
+
+0.69.2
+------
+
+* Lots of documentation improvements, including how to customise ``BaseUser``
+  (courtesy @sinisaos).
+* Fixed a bug with creating indexes when the column name clashes with a SQL
+  keyword (e.g. ``'order'``). See `PR 433 `_.
+  Thanks to @wmshort for reporting this issue.
+* Fixed an issue where some slots were incorrectly configured (courtesy
+  @ariebovenberg). See `PR 426 `_.
+
+-------------------------------------------------------------------------------
+
+0.69.1
+------
+
+Fixed a bug with auto migrations which rename columns - see
+`PR 423 `_. Thanks to
+@theelderbeever for reporting this, and @sinisaos for help investigating.
+
+-------------------------------------------------------------------------------
+
+0.69.0
+------
+
+Added `Xpresso `_ as a supported ASGI framework when
+using ``piccolo asgi new`` to generate a web app.
+
+Thanks to @sinisaos for adding this template, and @adriangb for reviewing.
+
+We also took this opportunity to update our FastAPI and BlackSheep ASGI
+templates.
+
+-------------------------------------------------------------------------------
+
+0.68.0
+------
+
+``Update`` queries without a ``where`` clause
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you try to perform an update query without a ``where`` clause you will now
+get an error:
+
+.. code-block:: python
+
+    >>> await Band.update({Band.name: 'New Band'})
+    UpdateError
+
+If you want to update all rows in the table, you can still do so, but you must
+pass ``force=True``.
+
+.. code-block:: python
+
+    >>> await Band.update({Band.name: 'New Band'}, force=True)
+
+This is similar to ``delete`` queries, which require a ``where`` clause or
+``force=True``.
+
+It was pointed out by @theelderbeever that an accidental mass update is almost
+as bad as a mass deletion, which is why this safety measure has been added.
+
+See `PR 412 `_.
+
+.. warning:: This is a breaking change. If you're doing update queries without
+   a where clause, you will need to add ``force=True``.
+
+``JSONB`` improvements
+~~~~~~~~~~~~~~~~~~~~~~
+
+Fixed some bugs with nullable ``JSONB`` columns. A value of ``None`` is now
+stored as ``null`` in the database, instead of the JSON string ``'null'``.
+Thanks to @theelderbeever for reporting this.
+
+See `PR 413 `_.
+
+-------------------------------------------------------------------------------
+
+0.67.0
+------
+
+create_user
+~~~~~~~~~~~
+
+``BaseUser`` now has a ``create_user`` method, which adds some extra password
+validation vs just instantiating and saving ``BaseUser`` directly.
+
+.. code-block:: python
+
+    >>> await BaseUser.create_user(username='bob', password='abc123XYZ')
+    <BaseUser: 1>
+
+We check that passwords are a reasonable length, and aren't already hashed.
+See `PR 402 `_.
+
+async first
+~~~~~~~~~~~
+
+All of the docs have been updated to show the async version of queries.
+
+For example:
+
+.. code-block:: python
+
+    # Previous:
+    Band.select().run_sync()
+
+    # Now:
+    await Band.select()
+
+Most people use Piccolo in async apps, and the playground supports top level
+await, so you can just paste in ``await Band.select()`` and it will still work.
+See `PR 407 `_.
+
+We decided to use ``await Band.select()`` instead of ``await Band.select().run()``.
+Both work, and have their merits, but the simpler version is probably easier
+for newcomers.
+
+-------------------------------------------------------------------------------
+
+0.66.1
+------
+
+In Piccolo you can print out any query to see the SQL which will be generated:
+
+.. code-block:: python
+
+    >>> print(Band.select())
+    SELECT "band"."id", "band"."name", "band"."manager", "band"."popularity" FROM band
+
+It didn't represent ``UUID`` and ``datetime`` values correctly, which is now fixed (courtesy @theelderbeever).
+See `PR 405 `_.
+
+-------------------------------------------------------------------------------
+
+0.66.0
+------
+
+Using descriptors to improve MyPy support (`PR 399 `_).
+
+MyPy is now able to correctly infer the type in lots of different scenarios:
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar()
+
+    # MyPy knows this is a Varchar
+    Band.name
+
+    band = Band()
+    band.name = "Pythonistas"  # MyPy knows we can assign strings when it's a class instance
+    band.name  # MyPy knows we will get a string back
+
+    band.name = 1  # MyPy knows this is an error, as we should only be allowed to assign strings
+
+-------------------------------------------------------------------------------
+
+0.65.1
+------
+
+Fixed a bug with ``BaseUser`` and Piccolo API.
+
+-------------------------------------------------------------------------------
+
+0.65.0
+------
+
+The ``BaseUser`` table hashes passwords before storing them in the database.
+
+When we create a fixture from the ``BaseUser`` table (using ``piccolo fixtures dump``),
+it looks something like:
+
+.. code-block:: json
+
+    {
+        "id": 11,
+        "username": "bob",
+        "password": "pbkdf2_sha256$10000$abc123"
+    }
+
+When we load the fixture (using ``piccolo fixtures load``) we need to be
+careful in case ``BaseUser`` tries to hash the password again (it would then be a hash of
+a hash, and hence incorrect). We now have additional checks in place to prevent
+this.
+
+Thanks to @mrbazzan for implementing this, and @sinisaos for help reviewing.
+
+-------------------------------------------------------------------------------
+
+0.64.0
+------
+
+Added initial support for ``ForeignKey`` columns referencing non-primary key
+columns. For example:
+
+.. code-block:: python
+
+    class Manager(Table):
+        name = Varchar()
+        email = Varchar(unique=True)
+
+    class Band(Table):
+        manager = ForeignKey(Manager, target_column=Manager.email)
+
+Thanks to @theelderbeever for suggesting this feature, and for helping to test
+it.
+
+-------------------------------------------------------------------------------
+
+0.63.1
+------
+
+Fixed an issue with the ``value_type`` of ``ForeignKey`` columns when
+referencing a table with a custom primary key column (such as a ``UUID``).
+
+-------------------------------------------------------------------------------
+
+0.63.0
+------
+
+Added an ``exclude_imported`` option to ``table_finder``.
+
+.. code-block:: python
+
+    APP_CONFIG = AppConfig(
+        table_classes=table_finder(['music.tables'], exclude_imported=True)
+    )
+
+It's useful when we want to import ``Table`` subclasses defined within a
+module itself, but not imported ones:
+
+.. code-block:: python
+
+    # tables.py
+    from piccolo.apps.user.tables import BaseUser  # excluded
+    from piccolo.columns.column_types import ForeignKey, Varchar
+    from piccolo.table import Table
+
+
+    class Musician(Table):  # included
+        name = Varchar()
+        user = ForeignKey(BaseUser)
+
+This was also possible using tags, but was less convenient. Thanks to @sinisaos
+for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.62.3
+------
+
+Fixed the error message in ``LazyTableReference``.
+
+Fixed a bug with ``create_pydantic_model`` with nested models. For example:
+
+.. code-block:: python
+
+    create_pydantic_model(Band, nested=(Band.manager,))
+
+Sometimes Pydantic couldn't uniquely identify the nested models. Thanks to
+@wmshort and @sinisaos for their help with this.
+ +------------------------------------------------------------------------------- + +0.62.2 +------ + +Added a max password length to the ``BaseUser`` table. By default it's set to +128 characters. + +------------------------------------------------------------------------------- + +0.62.1 +------ + +Fixed a bug with ``Readable`` when it contains lots of joins. + +``Readable`` is used to create a user friendly representation of a row in +Piccolo Admin. + +------------------------------------------------------------------------------- + +0.62.0 +------ + +Added Many-To-Many support. + +.. code-block:: python + + from piccolo.columns.column_types import ( + ForeignKey, + LazyTableReference, + Varchar + ) + from piccolo.columns.m2m import M2M + + + class Band(Table): + name = Varchar() + genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + + class Genre(Table): + name = Varchar() + bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + + # This is our joining table: + class GenreToBand(Table): + band = ForeignKey(Band) + genre = ForeignKey(Genre) + + + >>> await Band.select(Band.name, Band.genres(Genre.name, as_list=True)) + [ + { + "name": "Pythonistas", + "genres": ["Rock", "Folk"] + }, + ... + ] + +See the docs for more details. + +Many thanks to @sinisaos and @yezz123 for all the input. + +------------------------------------------------------------------------------- + +0.61.2 +------ + +Fixed some edge cases where migrations would fail if a column name clashed with +a reserved Postgres keyword (for example ``order`` or ``select``). + +We now have more robust tests for ``piccolo asgi new`` - as part of our CI we +actually run the generated ASGI app to make sure it works (thanks to @AliSayyah +and @yezz123 for their help with this). + +We also improved docstrings across the project. + +------------------------------------------------------------------------------- + +0.61.1 +------ + +Nicer ASGI template +~~~~~~~~~~~~~~~~~~~ + +When using ``piccolo asgi new`` to generate a web app, it now has a nicer home +page template, with improved styles. + +Improved schema generation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Fixed a bug with ``piccolo schema generate`` where it would crash if the column +type was unrecognised, due to failing to parse the column's default value. +Thanks to @gmos for reporting this issue, and figuring out the fix. + +Fix Pylance error +~~~~~~~~~~~~~~~~~ + +Added ``start_connection_pool`` and ``close_connection_pool`` methods to the +base ``Engine`` class (courtesy @gmos). + +------------------------------------------------------------------------------- + +0.61.0 +------ + +The ``save`` method now supports a ``columns`` argument, so when updating a +row you can specify which values to sync back. For example: + +.. code-block:: python + + band = await Band.objects().get(Band.name == "Pythonistas") + band.name = "Super Pythonistas" + await band.save([Band.name]) + + # Alternatively, strings are also supported: + await band.save(['name']) + +Thanks to @trondhindenes for suggesting this feature. + +------------------------------------------------------------------------------- + +0.60.2 +------ + +Fixed a bug with ``asyncio.gather`` not working with some query types. It was +due to them being dataclasses, and they couldn't be hashed properly. Thanks to +@brnosouza for reporting this issue. 
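+
+For example, code along these lines used to fail, and now works (a minimal
+sketch, assuming ``Band`` and ``Manager`` tables are defined):
+
+.. code-block:: python
+
+    import asyncio
+
+    # Queries can now be scheduled concurrently:
+    bands, managers = await asyncio.gather(
+        Band.select(),
+        Manager.select(),
+    )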
+
+-------------------------------------------------------------------------------
+
+0.60.1
+------
+
+Modified the import path for ``MigrationManager`` in migration files. It was
+confusing Pylance (VSCode's type checker). Thanks to @gmos for reporting and
+investigating this issue.
+
+-------------------------------------------------------------------------------
+
+0.60.0
+------
+
+Secret columns
+~~~~~~~~~~~~~~
+
+All column types can now be secret, rather than being limited to the
+``Secret`` column type which is a ``Varchar`` under the hood (courtesy
+@sinisaos).
+
+.. code-block:: python
+
+    class Manager(Table):
+        name = Varchar()
+        net_worth = Integer(secret=True)
+
+The reason this is useful is you can do queries such as:
+
+.. code-block:: python
+
+    >>> Manager.select(exclude_secrets=True).run_sync()
+    [{'id': 1, 'name': 'Guido'}]
+
+In the Piccolo API project we have ``PiccoloCRUD`` which is an incredibly
+powerful way of building an API with very little code. ``PiccoloCRUD`` has an
+``exclude_secrets`` option which lets you safely expose your data without
+leaking sensitive information.
+
+Pydantic improvements
+~~~~~~~~~~~~~~~~~~~~~
+
+max_recursion_depth
+*******************
+
+``create_pydantic_model`` now has a ``max_recursion_depth`` argument, which is
+useful when using ``nested=True`` on large database schemas.
+
+.. code-block:: python
+
+    >>> create_pydantic_model(MyTable, nested=True, max_recursion_depth=3)
+
+Nested tuple
+************
+
+You can now pass a tuple of columns as the argument to ``nested``:
+
+.. code-block:: python
+
+    >>> create_pydantic_model(Band, nested=(Band.manager,))
+
+This gives you more control than just using ``nested=True``.
+
+include_columns / exclude_columns
+*********************************
+
+You can now include / exclude columns from related tables. For example:
+
+.. code-block:: python
+
+    >>> create_pydantic_model(Band, nested=(Band.manager,), exclude_columns=(Band.manager.country,))
+
+Similarly:
+
+.. code-block:: python
+
+    >>> create_pydantic_model(Band, nested=(Band.manager,), include_columns=(Band.name, Band.manager.name))
+
+-------------------------------------------------------------------------------
+
+0.59.0
+------
+
+* When using ``piccolo asgi new`` to generate a FastAPI app, the generated code
+  is now cleaner. It also contains a ``conftest.py`` file, which encourages
+  people to use ``piccolo tester run`` rather than using ``pytest`` directly.
+* Tidied up docs, and added logo.
+* Clarified the use of the ``PICCOLO_CONF`` environment variable in the docs
+  (courtesy @theelderbeever).
+* ``create_pydantic_model`` now accepts an ``include_columns`` argument, in
+  case you only want a few columns in your model; it's faster than using
+  ``exclude_columns`` (courtesy @sinisaos).
+* Updated linters, and fixed new errors.
+
+-------------------------------------------------------------------------------
+
+0.58.0
+------
+
+Improved Pydantic docs
+~~~~~~~~~~~~~~~~~~~~~~
+
+The Pydantic docs used to be in the Piccolo API repo, but have been moved over
+to this repo. We took this opportunity to improve them significantly with
+additional examples. Courtesy @sinisaos.
+
+Internal code refactoring
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Some of the code has been optimised and cleaned up. Courtesy @yezz123.
+
+Schema generation for recursive foreign keys
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When using ``piccolo schema generate``, it would get stuck in a loop if a
+table had a foreign key column which referenced itself. 
Thanks to @knguyen5 +for reporting this issue, and @wmshort for implementing the fix. The output +will now look like: + +.. code-block:: python + + class Employee(Table): + name = Varchar() + manager = ForeignKey("self") + +Fixing a bug with Alter.add_column +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using the ``Alter.add_column`` API directly (not via migrations), it would +fail with foreign key columns. For example: + +.. code-block:: python + + SomeTable.alter().add_column( + name="my_fk_column", + column=ForeignKey(SomeOtherTable) + ).run_sync() + +This has now been fixed. Thanks to @wmshort for discovering this issue. + +create_pydantic_model improvements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Additional fields can now be added to the Pydantic schema. This is useful +when using Pydantic's JSON schema functionality: + +.. code-block:: python + + my_model = create_pydantic_model(Band, my_extra_field="Hello") + >>> my_model.schema() + {..., "my_extra_field": "Hello"} + +This feature was added to support new features in Piccolo Admin. + +Fixing a bug with import clashes in migrations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In certain situations it was possible to create a migration file with clashing +imports. For example: + +.. code-block:: python + + from uuid import UUID + from piccolo.columns.column_types import UUID + +Piccolo now tries to detect these clashes, and prevent them. If they can't be +prevented automatically, a warning is shown to the user. Courtesy @0scarB. + +------------------------------------------------------------------------------- + +0.57.0 +------ + +Added Python 3.10 support (courtesy @kennethcheo). + +------------------------------------------------------------------------------- + +0.56.0 +------ + +Fixed schema generation bug +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using ``piccolo schema generate`` to auto generate Piccolo ``Table`` +classes from an existing database, it would fail in this situation: + +* A table has a column with an index. +* The column name clashed with a Postgres type. + +For example, we couldn't auto generate this ``Table`` class: + +.. code-block:: python + + class MyTable(Table): + time = Timestamp(index=True) + +This is because ``time`` is a builtin Postgres type, and the ``CREATE INDEX`` +statement being inspected in the database wrapped the column name in quotes, +which broke our regex. + +Thanks to @knguyen5 for fixing this. + +Improved testing docs +~~~~~~~~~~~~~~~~~~~~~ + +A convenience method called ``get_table_classes`` was added to ``Finder``. + +``Finder`` is the main class in Piccolo for dynamically importing projects / +apps / tables / migrations etc. + +``get_table_classes`` lets us easily get the ``Table`` classes for a project. +This makes writing unit tests easier, when we need to setup a schema. + +.. code-block:: python + + from unittest import TestCase + + from piccolo.table import create_tables, drop_tables + from piccolo.conf.apps import Finder + + TABLES = Finder().get_table_classes() + + class TestApp(TestCase): + def setUp(self): + create_tables(*TABLES) + + def tearDown(self): + drop_tables(*TABLES) + + def test_app(self): + # Do some testing ... + pass + +The docs were updated to reflect this. + +When dropping tables in a unit test, remember to use ``piccolo tester run``, to +make sure the test database is used. + +get_output_schema +~~~~~~~~~~~~~~~~~ + +``get_output_schema`` is the main entrypoint for database reflection in +Piccolo. 
It has been modified to accept an optional ``Engine`` argument, which +makes it more flexible. + +------------------------------------------------------------------------------- + +0.55.0 +------ + +Table._meta.refresh_db +~~~~~~~~~~~~~~~~~~~~~~ + +Added the ability to refresh the database engine. + +.. code-block:: python + + MyTable._meta.refresh_db() + +This causes the ``Table`` to fetch the ``Engine`` again from your +``piccolo_conf.py`` file. The reason this is useful, is you might change the +``PICCOLO_CONF`` environment variable, and some ``Table`` classes have +already imported an engine. This is now used by the ``piccolo tester run`` +command to ensure all ``Table`` classes have the correct engine. + +ColumnMeta edge cases +~~~~~~~~~~~~~~~~~~~~~ + +Fixed an edge case where ``ColumnMeta`` couldn't be copied if it had extra +attributes added to it. + +Improved column type conversion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When running migrations which change column types, Piccolo now provides the +``USING`` clause to the ``ALTER COLUMN`` DDL statement, which makes it more +likely that type conversion will be successful. + +For example, if there is an ``Integer`` column, and it's converted to a +``Varchar`` column, the migration will run fine. In the past, running this in +reverse would fail. Now Postgres will try and cast the values back to integers, +which makes reversing migrations more likely to succeed. + +Added drop_tables +~~~~~~~~~~~~~~~~~ + +There is now a convenience function for dropping several tables in one go. If +the database doesn't support ``CASCADE``, then the tables are sorted based on +their ``ForeignKey`` columns, so they're dropped in the correct order. It all +runs inside a transaction. + +.. code-block:: python + + from piccolo.table import drop_tables + + drop_tables(Band, Manager) + +This is a useful tool in unit tests. + +Index support in schema generation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using ``piccolo schema generate``, Piccolo will now reflect the indexes +from the database into the generated ``Table`` classes. Thanks to @wmshort for +this. + +------------------------------------------------------------------------------- + +0.54.0 +------ +Added the ``db_column_name`` option to columns. This is for edge cases where +a legacy database is being used, with problematic column names. For example, +if a column is called ``class``, this clashes with a Python builtin, so the +following isn't possible: + +.. code-block:: text + + class MyTable(Table): + class = Varchar() # Syntax error! + +You can now do the following: + +.. code-block:: python + + class MyTable(Table): + class_ = Varchar(db_column_name='class') + +Here are some example queries using it: + +.. code-block:: python + + # Create - both work as expected + MyTable(class_='Test').save().run_sync() + MyTable.objects().create(class_='Test').run_sync() + + # Objects + row = MyTable.objects().first().where(MyTable.class_ == 'Test').run_sync() + >>> row.class_ + 'Test' + + # Select + >>> MyTable.select().first().where(MyTable.class_ == 'Test').run_sync() + {'id': 1, 'class': 'Test'} + +------------------------------------------------------------------------------- + +0.53.0 +------ +An internal code clean up (courtesy @yezz123). + +Dramatically improved CLI appearance when running migrations (courtesy +@wmshort). + +Added a runtime reflection feature, where ``Table`` classes can be generated +on the fly from existing database tables (courtesy @AliSayyah). 
This is useful +when dealing with very dynamic databases, where tables are frequently being +added / modified, so hard coding them in a ``tables.py`` is impractical. Also, +for exploring databases on the command line. It currently just supports +Postgres. + +Here's an example: + +.. code-block:: python + + from piccolo.table_reflection import TableStorage + + storage = TableStorage() + Band = await storage.get_table('band') + >>> await Band.select().run() + [{'id': 1, 'name': 'Pythonistas', 'manager': 1}, ...] + +------------------------------------------------------------------------------- + +0.52.0 +------ +Lots of improvements to ``piccolo schema generate``: + +* Dramatically improved performance, by executing more queries in parallel + (courtesy @AliSayyah). +* If a table in the database has a foreign key to a table in another + schema, this will now work (courtesy @AliSayyah). +* The column defaults are now extracted from the database (courtesy @wmshort). +* The ``scale`` and ``precision`` values for ``Numeric`` / ``Decimal`` column + types are extracted from the database (courtesy @wmshort). +* The ``ON DELETE`` and ``ON UPDATE`` values for ``ForeignKey`` columns are + now extracted from the database (courtesy @wmshort). + +Added ``BigSerial`` column type (courtesy @aliereno). + +Added GitHub issue templates (courtesy @AbhijithGanesh). + +------------------------------------------------------------------------------- + +0.51.1 +------ +Fixing a bug with ``on_delete`` and ``on_update`` not being set correctly. +Thanks to @wmshort for discovering this. + +------------------------------------------------------------------------------- + +0.51.0 +------ +Modified ``create_pydantic_model``, so ``JSON`` and ``JSONB`` columns have a +``format`` attribute of ``'json'``. This will be used by Piccolo Admin for +improved JSON support. Courtesy @sinisaos. + +Fixing a bug where the ``piccolo fixtures load`` command wasn't registered +with the Piccolo CLI. + +------------------------------------------------------------------------------- + +0.50.0 +------ +The ``where`` clause can now accept multiple arguments (courtesy @AliSayyah): + +.. code-block:: python + + Concert.select().where( + Concert.venue.name == 'Royal Albert Hall', + Concert.band_1.name == 'Pythonistas' + ).run_sync() + +It's another way of expressing `AND`. It's equivalent to both of these: + +.. code-block:: python + + Concert.select().where( + Concert.venue.name == 'Royal Albert Hall' + ).where( + Concert.band_1.name == 'Pythonistas' + ).run_sync() + + Concert.select().where( + (Concert.venue.name == 'Royal Albert Hall') & (Concert.band_1.name == 'Pythonistas') + ).run_sync() + +Added a ``create`` method, which is an easier way of creating objects (courtesy +@AliSayyah). + +.. code-block:: python + + # This still works: + band = Band(name="C-Sharps", popularity=100) + band.save().run_sync() + + # But now we can do it in a single line using `create`: + band = Band.objects().create(name="C-Sharps", popularity=100).run_sync() + +Fixed a bug with ``piccolo schema generate`` where columns with unrecognised +column types were omitted from the output (courtesy @AliSayyah). + +Added docs for the ``--trace`` argument, which can be used with Piccolo +commands to get a traceback if the command fails (courtesy @hipertracker). + +Added ``DoublePrecision`` column type, which is similar to ``Real`` in that +it stores ``float`` values. However, those values are stored at greater +precision (courtesy @AliSayyah). 
+
+Improved ``AppRegistry``, so if a user only adds the app name (e.g. ``blog``),
+instead of ``blog.piccolo_app``, it will now emit a warning, and will try to
+import ``blog.piccolo_app`` (courtesy @aliereno).
+
+-------------------------------------------------------------------------------
+
+0.49.0
+------
+Fixed a bug with ``create_pydantic_model`` when used with a ``Decimal`` /
+``Numeric`` column when no ``digits`` argument was set (courtesy @AliSayyah).
+
+Added the ``create_tables`` function, which accepts a sequence of ``Table``
+subclasses, then sorts them based on their ``ForeignKey`` columns, and creates
+them. This is really useful for people who aren't using migrations (for
+example, when using Piccolo in a simple data science script). Courtesy
+@AliSayyah.
+
+.. code-block:: python
+
+    from piccolo.table import create_tables
+
+    create_tables(Band, Manager, if_not_exists=True)
+
+    # Equivalent to:
+    Manager.create_table(if_not_exists=True).run_sync()
+    Band.create_table(if_not_exists=True).run_sync()
+
+Fixed typos with the new fixtures app - sometimes it was referred to as
+``fixture`` and other times ``fixtures``. It's now standardised as
+``fixtures`` (courtesy @hipertracker).
+
+-------------------------------------------------------------------------------
+
+0.48.0
+------
+The ``piccolo user create`` command can now be used by passing in command line
+arguments, instead of using the interactive prompt (courtesy @AliSayyah).
+
+For example ``piccolo user create --username=bob ...``.
+
+This is useful when you want to create users in a script.
+
+-------------------------------------------------------------------------------
+
+0.47.0
+------
+You can now use ``pip install piccolo[all]``, which will install all optional
+requirements.
+
+-------------------------------------------------------------------------------
+
+0.46.0
+------
+Added the fixtures app. This is used to dump data from a database to a JSON
+file, and then reload it again. It's useful for seeding a database with
+essential data, whether that's a colleague setting up their local environment,
+or deploying to production.
+
+To create a fixture:
+
+.. code-block:: bash
+
+    piccolo fixtures dump --apps=blog > fixture.json
+
+To load a fixture:
+
+.. code-block:: bash
+
+    piccolo fixtures load fixture.json
+
+As part of this change, Piccolo's Pydantic support was brought into this
+library (prior to this it only existed within the ``piccolo_api`` library). At
+a later date, the ``piccolo_api`` library will be updated, so its Pydantic
+code just proxies to what's within the main ``piccolo`` library.
+
+-------------------------------------------------------------------------------
+
+0.45.1
+------
+Improvements to ``piccolo schema generate``. It's now smarter about which
+imports to include. Also, the generated ``Table`` classes will now be sorted
+based on their ``ForeignKey`` columns. Internally the sorting algorithm has
+been changed to use the ``graphlib`` module, which was added in Python 3.9.
+
+-------------------------------------------------------------------------------
+
+0.45.0
+------
+Added the ``piccolo schema graph`` command for visualising your database
+structure, which outputs a Graphviz file. It can then be turned into an
+image, for example:
+
+.. code-block:: bash
+
+    piccolo schema graph | dot -Tpdf -o graph.pdf
+
+Also made some minor changes to the ASGI templates, to reduce MyPy errors.
+
+-------------------------------------------------------------------------------
+
+0.44.1
+------
+Updated ``to_dict`` so it works with nested objects, as introduced by the
+``prefetch`` functionality.
+
+For example:
+
+.. code-block:: python
+
+    band = Band.objects(Band.manager).first().run_sync()
+
+    >>> band.to_dict()
+    {'id': 1, 'name': 'Pythonistas', 'manager': {'id': 1, 'name': 'Guido'}}
+
+It also works with filtering:
+
+.. code-block:: python
+
+    >>> band.to_dict(Band.name, Band.manager.name)
+    {'name': 'Pythonistas', 'manager': {'name': 'Guido'}}
+
+-------------------------------------------------------------------------------
+
+0.44.0
+------
+Added the ability to prefetch related objects. Here's an example:
+
+.. code-block:: python
+
+    band = await Band.objects(Band.manager).first().run()
+    >>> band.manager
+    <Manager: 1>
+
+If a table has a lot of ``ForeignKey`` columns, there's a useful shortcut,
+which will return all of the related rows as objects.
+
+.. code-block:: python
+
+    concert = await Concert.objects(Concert.all_related()).first().run()
+    >>> concert.band_1
+    <Band: 1>
+    >>> concert.band_2
+    <Band: 2>
+    >>> concert.venue
+    <Venue: 1>
+
+Thanks to @wmshort for all the input.
+
+-------------------------------------------------------------------------------
+
+0.43.0
+------
+Migrations containing ``Array``, ``JSON`` and ``JSONB`` columns should be
+more reliable now. More unit tests were added to cover edge cases.
+
+-------------------------------------------------------------------------------
+
+0.42.0
+------
+You can now use ``all_columns`` at the root. For example:
+
+.. code-block:: python
+
+    await Band.select(
+        Band.all_columns(),
+        Band.manager.all_columns()
+    ).run()
+
+You can also exclude certain columns if you like:
+
+.. code-block:: python
+
+    await Band.select(
+        Band.all_columns(exclude=[Band.id]),
+        Band.manager.all_columns(exclude=[Band.manager.id])
+    ).run()
+
+-------------------------------------------------------------------------------
+
+0.41.1
+------
+Fixed a regression where, if multiple tables were created in a single
+migration file, it could fail because they were applied in the wrong order.
+
+-------------------------------------------------------------------------------
+
+0.41.0
+------
+Fixed a bug where if ``all_columns`` was used two or more levels deep, it would
+fail. Thanks to @wmshort for reporting this issue.
+
+Here's an example:
+
+.. code-block:: python
+
+    Concert.select(
+        Concert.venue.name,
+        *Concert.band_1.manager.all_columns()
+    ).run_sync()
+
+Also, the ``ColumnsDelegate`` has now been tweaked, so unpacking of
+``all_columns`` is optional.
+
+.. code-block:: python
+
+    # This now works the same as the code above (we have omitted the *)
+    Concert.select(
+        Concert.venue.name,
+        Concert.band_1.manager.all_columns()
+    ).run_sync()
+
+-------------------------------------------------------------------------------
+
+0.40.1
+------
+Loosened the ``typing-extensions`` requirement, as it was causing issues when
+installing ``asyncpg``.
+
+-------------------------------------------------------------------------------
+
+0.40.0
+------
+Added ``nested`` output option, which makes the response from a ``select``
+query use nested dictionaries:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, *Band.manager.all_columns()).output(nested=True).run()
+    [{'name': 'Pythonistas', 'manager': {'id': 1, 'name': 'Guido'}}]
+
+Thanks to @wmshort for the idea.
+
+-------------------------------------------------------------------------------
+
+0.39.0
+------
+Added ``to_dict`` method to ``Table``.
+
+If you just use ``__dict__`` on a ``Table`` instance, you get some non-column
+values. By using ``to_dict`` it's just the column values. Here's an example:
+
+.. code-block:: python
+
+    class MyTable(Table):
+        name = Varchar()
+
+    instance = MyTable.objects().first().run_sync()
+
+    >>> instance.__dict__
+    {'_exists_in_db': True, 'id': 1, 'name': 'foo'}
+
+    >>> instance.to_dict()
+    {'id': 1, 'name': 'foo'}
+
+Thanks to @wmshort for the idea, and @aminalaee and @sinisaos for investigating
+edge cases.
+
+-------------------------------------------------------------------------------
+
+0.38.2
+------
+Removed problematic type hint which assumed pytest was installed.
+
+-------------------------------------------------------------------------------
+
+0.38.1
+------
+Minor changes to ``get_or_create`` to make sure it handles joins correctly.
+
+.. code-block:: python
+
+    instance = (
+        Band.objects()
+        .get_or_create(
+            (Band.name == "My new band")
+            & (Band.manager.name == "Excellent manager")
+        )
+        .run_sync()
+    )
+
+In this situation, there are two columns called ``name`` - we need to make sure
+the correct value is applied if the row doesn't exist.
+
+-------------------------------------------------------------------------------
+
+0.38.0
+------
+``get_or_create`` now supports more complex where clauses. For example:
+
+.. code-block:: python
+
+    row = await Band.objects().get_or_create(
+        (Band.name == 'Pythonistas') & (Band.popularity == 1000)
+    ).run()
+
+And you can find out whether the row was created or not using
+``row._was_created``.
+
+Thanks to @wmshort for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.37.0
+------
+Added ``ModelBuilder``, which can be used to generate data for tests (courtesy
+@aminalaee).
+
+-------------------------------------------------------------------------------
+
+0.36.0
+------
+Fixed an issue where ``like`` and ``ilike`` clauses required a wildcard. For
+example:
+
+.. code-block:: python
+
+    await Manager.select().where(Manager.name.ilike('Guido%')).run()
+
+You can now omit wildcards if you like:
+
+.. code-block:: python
+
+    await Manager.select().where(Manager.name.ilike('Guido')).run()
+
+Which would match on ``'guido'`` and ``'Guido'``, but not ``'Guidoxyz'``.
+
+Thanks to @wmshort for reporting this issue.
+
+-------------------------------------------------------------------------------
+
+0.35.0
+------
+* Improved ``PrimaryKey`` deprecation warning (courtesy @tonybaloney).
+* Added ``piccolo schema generate`` which creates a Piccolo schema from an
+  existing database.
+* Added ``piccolo tester run`` which is a wrapper around pytest, and
+  temporarily sets ``PICCOLO_CONF``, so a test database is used.
+* Added the ``get`` convenience method (courtesy @aminalaee). It returns the
+  first matching record, or ``None`` if there's no match. For example:
+
+  .. code-block:: python
+
+      manager = await Manager.objects().get(Manager.name == 'Guido').run()
+
+      # This is equivalent to:
+      manager = await Manager.objects().where(Manager.name == 'Guido').first().run()
+
+-------------------------------------------------------------------------------
+
+0.34.0
+------
+Added the ``get_or_create`` convenience method (courtesy @aminalaee). Example
+usage:
+
+.. code-block:: python
+
+    manager = await Manager.objects().get_or_create(
+        Manager.name == 'Guido'
+    ).run()
+
+-------------------------------------------------------------------------------
+
+0.33.1
+------
+* Bug fix, where ``compare_dicts`` was failing in migrations if any ``Column``
+  had an unhashable type as an argument. For example: ``Array(default=[])``.
+  Thanks to @hipertracker for reporting this problem.
+* Increased the minimum version of orjson, so binaries are available for Macs
+  running on Apple silicon (courtesy @hipertracker).
+
+-------------------------------------------------------------------------------
+
+0.33.0
+------
+Fix for auto migrations when using custom primary keys (thanks to @adriangb and
+@aminalaee for investigating this issue).
+
+-------------------------------------------------------------------------------
+
+0.32.0
+------
+Migrations can now have a description, which is shown when using
+``piccolo migrations check``. This makes migrations easier to identify (thanks
+to @davidolrik for the idea).
+
+-------------------------------------------------------------------------------
+
+0.31.0
+------
+Added an ``all_columns`` method, to make it easier to retrieve all related
+columns when doing a join. For example:
+
+.. code-block:: python
+
+    await Band.select(Band.name, *Band.manager.all_columns()).first().run()
+
+Changed the instructions for installing additional dependencies, so they're
+wrapped in quotes, to make sure it works on ZSH (i.e.
+``pip install 'piccolo[postgres]'`` instead of
+``pip install piccolo[postgres]``).
+
+-------------------------------------------------------------------------------
+
+0.30.0
+------
+The database drivers are now installed separately. For example:
+``pip install piccolo[postgres]`` (courtesy @aminalaee).
+
+For some users this might be a **breaking change** - please make sure that for
+existing Piccolo projects, you have either ``asyncpg``, or
+``piccolo[postgres]`` in your ``requirements.txt`` file.
+
+-------------------------------------------------------------------------------
+
+0.29.0
+------
+The user can now specify the primary key column (courtesy @aminalaee). For
+example:
+
+.. code-block:: python
+
+    class RecordingStudio(Table):
+        pk = UUID(primary_key=True)
+
+The BlackSheep template generated by ``piccolo asgi new`` now supports mounting
+of the Piccolo Admin (courtesy @sinisaos).
+
+-------------------------------------------------------------------------------
+
+0.28.0
+------
+Added aggregation functions, such as ``Sum``, ``Min``, ``Max`` and ``Avg``,
+for use in select queries (courtesy @sinisaos).
+
+-------------------------------------------------------------------------------
+
+0.27.0
+------
+Added uvloop as an optional dependency, installed via
+``pip install piccolo[uvloop]`` (courtesy @aminalaee). uvloop is a faster
+implementation of the asyncio event loop found in Python's standard library.
+When uvloop is installed, Piccolo will use it to increase the performance of
+the Piccolo CLI, and web servers such as Uvicorn will use it to increase the
+performance of your ASGI app.
+
+-------------------------------------------------------------------------------
+
+0.26.0
+------
+Added ``eq`` and ``ne`` methods to the ``Boolean`` column, which can be used
+if linters complain about using ``SomeTable.some_column == True``.
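+
+For example (a short sketch, reusing the hypothetical ``SomeTable`` from
+above):
+
+.. code-block:: python
+
+    # Equivalent to `SomeTable.some_column == True`, but keeps linters happy:
+    await SomeTable.select().where(SomeTable.some_column.eq(True)).run()
+
+    # Equivalent to `SomeTable.some_column != True`:
+    await SomeTable.select().where(SomeTable.some_column.ne(True)).run()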
+
+-------------------------------------------------------------------------------
+
+0.25.0
+------
+* Changed the migration IDs, so the timestamp now includes microseconds. This
+  is to make clashing migration IDs much less likely.
+* Added a lot of end-to-end tests for migrations, which revealed some bugs
+  in ``Column`` defaults.
+
+-------------------------------------------------------------------------------
+
+0.24.1
+------
+A bug fix for migrations. See
+`issue 123 <https://github.com/piccolo-orm/piccolo/issues/123>`_
+for more information.
+
+-------------------------------------------------------------------------------
+
+0.24.0
+------
+Lots of improvements to ``JSON`` and ``JSONB`` columns. Piccolo will now
+automatically convert between Python types and JSON strings. For example, with
+this schema:
+
+.. code-block:: python
+
+    class RecordingStudio(Table):
+        name = Varchar()
+        facilities = JSON()
+
+We can now do the following:
+
+.. code-block:: python
+
+    RecordingStudio(
+        name="Abbey Road",
+        facilities={'mixing_desk': True}  # Will automatically be converted to a JSON string
+    ).save().run_sync()
+
+Similarly, when fetching data from a JSON column, Piccolo can now automatically
+deserialise it.
+
+.. code-block:: python
+
+    >>> RecordingStudio.select().output(load_json=True).run_sync()
+    [{'id': 1, 'name': 'Abbey Road', 'facilities': {'mixing_desk': True}}]
+
+    >>> studio = RecordingStudio.objects().first().output(load_json=True).run_sync()
+    >>> studio.facilities
+    {'mixing_desk': True}
+
+-------------------------------------------------------------------------------
+
+0.23.0
+------
+Added the ``create_table_class`` function, which can be used to create
+``Table`` subclasses at runtime. This was required to fix an existing bug,
+which was affecting migrations (see
+`issue 111 <https://github.com/piccolo-orm/piccolo/issues/111>`_
+for more details).
+
+-------------------------------------------------------------------------------
+
+0.22.0
+------
+* An error is now raised if a user tries to create a Piccolo app using
+  ``piccolo app new`` with the same name as a builtin Python module, as it
+  will cause strange bugs.
+* Fixing a strange bug where using an expression such as
+  ``Concert.band_1.manager.id`` in a query would cause an error. It only
+  happened if multiple joins were involved, and the last column in the chain
+  was ``id``.
+* ``where`` clauses can now accept ``Table`` instances. For example:
+  ``await Band.select().where(Band.manager == some_manager).run()``, instead
+  of having to explicitly reference the ``id``.
+
+-------------------------------------------------------------------------------
+
+0.21.2
+------
+Fixing a bug with serialising ``Enum`` instances in migrations. For example:
+``Varchar(default=Colour.red)``.
+
+-------------------------------------------------------------------------------
+
+0.21.1
+------
+Fix missing imports in FastAPI and Starlette app templates.
+
+-------------------------------------------------------------------------------
+
+0.21.0
+------
+* Added a ``freeze`` method to ``Query``.
+* Added BlackSheep as an option to ``piccolo asgi new``.
+
+-------------------------------------------------------------------------------
+
+0.20.0
+------
+Added ``choices`` option to ``Column``.
+
+-------------------------------------------------------------------------------
+
+0.19.1
+------
+* Added ``piccolo user change_permissions`` command.
+* Added aliases for CLI commands.
+
+-------------------------------------------------------------------------------
+
+0.19.0
+------
+Changes to the ``BaseUser`` table - added ``superuser`` and ``last_login``
+columns. These are required for upgrades to Piccolo Admin.
+
+If you're using migrations, then running ``piccolo migrations forwards all``
+should add these new columns for you.
+
+If not using migrations, the ``BaseUser`` table can be upgraded using the
+following DDL statements:
+
+.. code-block:: sql
+
+    ALTER TABLE piccolo_user ADD COLUMN "superuser" BOOLEAN NOT NULL DEFAULT false
+    ALTER TABLE piccolo_user ADD COLUMN "last_login" TIMESTAMP DEFAULT null
+
+-------------------------------------------------------------------------------
+
+0.18.4
+------
+* Fixed a bug when multiple tables inherit from the same mixin (thanks to
+  @brnosouza).
+* Added a ``log_queries`` option to ``PostgresEngine``, which is useful during
+  debugging.
+* Added the ``inflection`` library for converting ``Table`` class names to
+  database table names. Previously, a class called ``TableA`` would wrongly
+  have a table called ``tablea`` instead of ``table_a``.
+* Fixed a bug with ``SerialisedBuiltin.__hash__`` not returning a number,
+  which could break migrations (thanks to @sinisaos).
+
+-------------------------------------------------------------------------------
+
+0.18.3
+------
+Improved ``Array`` column serialisation - needed to fix auto migrations.
+
+-------------------------------------------------------------------------------
+
+0.18.2
+------
+Added support for filtering ``Array`` columns.
+
+-------------------------------------------------------------------------------
+
+0.18.1
+------
+Add the ``Array`` column type as a top level import in ``piccolo.columns``.
+
+-------------------------------------------------------------------------------
+
+0.18.0
+------
+* Refactored ``forwards`` and ``backwards`` commands for migrations, to make
+  them easier to run programmatically.
+* Added a simple ``Array`` column type.
+* ``table_finder`` now works if just a string is passed in, instead of having
+  to pass in an array of strings.
+
+-------------------------------------------------------------------------------
+
+0.17.5
+------
+Catching database connection exceptions when starting the default ASGI app
+created with ``piccolo asgi new`` - these errors occur if the Postgres
+database hasn't been created yet.
+
+-------------------------------------------------------------------------------
+
+0.17.4
+------
+Added a ``help_text`` option to the ``Table`` metaclass. This is used in
+Piccolo Admin to show tooltips.
+
+-------------------------------------------------------------------------------
+
+0.17.3
+------
+Added a ``help_text`` option to the ``Column`` constructor. This is used in
+Piccolo Admin to show tooltips.
+
+-------------------------------------------------------------------------------
+
+0.17.2
+------
+* Exposing ``index_type`` in the ``Column`` constructor.
+* Fixing a typo with ``start_connection_pool`` and ``close_connection_pool`` -
+  thanks to paolodina for finding this.
+* Fixing a typo in the ``PostgresEngine`` docs - courtesy of paolodina.
+
+-------------------------------------------------------------------------------
+
+0.17.1
+------
+Fixing a bug with ``SchemaSnapshot`` if column types were changed in migrations
+- the snapshot didn't reflect the changes.
+
+-------------------------------------------------------------------------------
+
+0.17.0
+------
+* Migrations now directly import ``Column`` classes - this allows users to
+  create custom ``Column`` subclasses. Migrations previously only worked with
+  the builtin column types.
+* Migrations now detect if the column type has changed, and will try to
+  convert it automatically.
+
+-------------------------------------------------------------------------------
+
+0.16.5
+------
+The Postgres extensions that ``PostgresEngine`` tries to enable at startup
+can now be configured.
+
+-------------------------------------------------------------------------------
+
+0.16.4
+------
+* Fixed a bug with ``MyTable.column != None``
+* Added ``is_null`` and ``is_not_null`` methods, to avoid linting issues when
+  comparing with None.
+
+-------------------------------------------------------------------------------
+
+0.16.3
+------
+* Added ``WhereRaw``, so raw SQL can be used in where clauses.
+* ``piccolo shell run`` now uses syntax highlighting - courtesy of Fingel.
+
+-------------------------------------------------------------------------------
+
+0.16.2
+------
+Reordered the dependencies in requirements.txt when using ``piccolo asgi new``,
+as the latest FastAPI and Starlette versions are incompatible.
+
+-------------------------------------------------------------------------------
+
+0.16.1
+------
+Added ``Timestamptz`` column type, for storing datetimes which are timezone
+aware.
+
+-------------------------------------------------------------------------------
+
+0.16.0
+------
+* Fixed a bug with creating a ``ForeignKey`` column with ``references="self"``
+  in auto migrations.
+* Changed migration file naming, so the file names no longer contain
+  characters which are unsupported on Windows.
+
+-------------------------------------------------------------------------------
+
+0.15.1
+------
+Changed the status code returned when creating a migration and no changes are
+detected. It now returns a status code of 0, so it doesn't fail build scripts.
+
+-------------------------------------------------------------------------------
+
+0.15.0
+------
+Added ``Bytea`` / ``Blob`` column type.
+
+-------------------------------------------------------------------------------
+
+0.14.13
+-------
+Fixing a bug with migrations which drop column defaults.
+
+-------------------------------------------------------------------------------
+
+0.14.12
+-------
+* Fixing a bug where re-running ``Table.create(if_not_exists=True)`` would
+  fail if it contained columns with indexes.
+* Raising a ``ValueError`` if a relative path is provided to ``ForeignKey``
+  ``references``. For example, ``.tables.Manager``. The paths must be absolute
+  for now.
+
+-------------------------------------------------------------------------------
+
+0.14.11
+-------
+Fixing a bug with ``Boolean`` column defaults, caused by the ``Table``
+metaclass not being explicit enough when checking falsy values.
+
+-------------------------------------------------------------------------------
+
+0.14.10
+-------
+* The ``ForeignKey`` ``references`` argument can now be specified using a
+  string, or a ``LazyTableReference`` instance, rather than just a ``Table``
+  subclass. This allows a ``Table`` to be specified which is in a Piccolo app,
+  or Python module. The ``Table`` is only loaded after imports have completed,
+  which prevents circular import issues.
+* Faster column copying, which is important when specifying joins, e.g.
+  ``await Band.select(Band.manager.name).run()``.
+* Fixed a bug with migrations and foreign key constraints.
+
+-------------------------------------------------------------------------------
+
+0.14.9
+------
+Modified the exit codes for the ``forwards`` and ``backwards`` commands when no
+migrations are left to run / reverse. Otherwise build scripts may fail.
+
+-------------------------------------------------------------------------------
+
+0.14.8
+------
+* Improved the method signature of the ``output`` query clause (explicitly
+  added args, instead of using ``**kwargs``).
+* Fixed a bug where ``output(as_list=True)`` would fail if no rows were found.
+* Made ``piccolo migrations forwards`` command output more legible.
+* Improved renamed table detection in migrations.
+* Added the ``piccolo migrations clean`` command for removing orphaned rows
+  from the migrations table.
+* Fixed a bug where ``get_migration_managers`` wasn't inclusive.
+* Raising a ``ValueError`` if ``is_in`` or ``not_in`` query clauses are passed
+  an empty list.
+* Changed the migration commands to be top level async.
+* Combined ``print`` and ``sys.exit`` statements.
+
+-------------------------------------------------------------------------------
+
+0.14.7
+------
+* Added missing type annotation for ``run_sync``.
+* Updating type annotations for column default values - allowing callables.
+* Replaced instances of ``asyncio.run`` with ``run_sync``.
+* Tidied up aiosqlite imports.
+
+-------------------------------------------------------------------------------
+
+0.14.6
+------
+* Added JSON and JSONB column types, and the arrow function for JSONB.
+* Fixed a bug with the distinct clause.
+* Added ``as_alias``, so select queries can override column names in the
+  response (i.e. ``SELECT foo AS bar FROM baz``).
+* Refactored JSON encoding into a separate utils file.
+
+-------------------------------------------------------------------------------
+
+0.14.5
+------
+* Removed old iPython version recommendation in the ``piccolo shell run`` and
+  ``piccolo playground run`` commands, and enabled top level await.
+* Fixing outstanding mypy warnings.
+* Added optional requirements for the playground to ``setup.py``.
+
+-------------------------------------------------------------------------------
+
+0.14.4
+------
+* Added ``piccolo sql_shell run`` command, which launches the psql or sqlite3
+  shell, using the connection parameters defined in ``piccolo_conf.py``.
+  This is convenient when you want to run raw SQL on your database.
+* ``run_sync`` now handles more edge cases, for example if there's already
+  an event loop in the current thread.
+* Removed asgiref dependency.
+
+-------------------------------------------------------------------------------
+
+0.14.3
+------
+* Queries can be directly awaited - ``await MyTable.select()``, as an
+  alternative to using the run method ``await MyTable.select().run()``.
+* The ``piccolo asgi new`` command now accepts a ``name`` argument, which is
+  used to populate the default database name within the template.
+
+-------------------------------------------------------------------------------
+
+0.14.2
+------
+* Centralised code for importing Piccolo apps and tables - laying the
+  foundation for fixtures.
+* Made orjson an optional dependency, installable using
+  ``pip install piccolo[orjson]``.
+* Improved version number parsing in Postgres.
+
+-------------------------------------------------------------------------------
+
+0.14.1
+------
+Fixing a bug with dropping tables in auto migrations.
+
+-------------------------------------------------------------------------------
+
+0.14.0
+------
+Added ``Interval`` column type.
+
+-------------------------------------------------------------------------------
+
+0.13.5
+------
+* Added ``allowed_hosts`` to ``create_admin`` in ASGI template.
+* Fixing bug with default ``root`` argument in some piccolo commands.
+
+-------------------------------------------------------------------------------
+
+0.13.4
+------
+* Fixed bug with ``SchemaSnapshot`` when dropping columns.
+* Added custom ``__repr__`` method to ``Table``.
+
+-------------------------------------------------------------------------------
+
+0.13.3
+------
+Added ``piccolo shell run`` command for running adhoc queries using Piccolo.
+
+-------------------------------------------------------------------------------
+
+0.13.2
+------
+* Fixing bug with auto migrations when dropping columns.
+* Added a ``root`` argument to ``piccolo asgi new``, ``piccolo app new`` and
+  ``piccolo project new`` commands, to override where the files are placed.
+
+-------------------------------------------------------------------------------
+
+0.13.1
+------
+Added support for ``group_by`` and ``Count`` for aggregate queries.
+
+-------------------------------------------------------------------------------
+
+0.13.0
+------
+Added ``required`` argument to ``Column``. This allows the user to indicate
+which fields must be provided by the user. Other tools can use this value when
+generating forms and serialisers.
+
+-------------------------------------------------------------------------------
+
+0.12.6
+------
+* Fixing a typo in ``TimestampCustom`` arguments.
+* Fixing bug in ``TimestampCustom`` SQL representation.
+* Added more extensive deserialisation for migrations.
+
+-------------------------------------------------------------------------------
+
+0.12.5
+------
+* Improved ``PostgresEngine`` docstring.
+* Resolving rename migrations before adding columns.
+* Fixed bug serialising ``TimestampCustom``.
+* Fixed bug with altering column defaults to be non-static values.
+* Removed ``response_handler`` from ``Alter`` query.
+
+-------------------------------------------------------------------------------
+
+0.12.4
+------
+Using orjson for JSON serialisation when using the ``output(as_json=True)``
+clause. It supports more Python types than ujson.
+
+-------------------------------------------------------------------------------
+
+0.12.3
+------
+Improved ``piccolo user create`` command - defaults the username to the current
+system user.
+
+-------------------------------------------------------------------------------
+
+0.12.2
+------
+Fixing bug when sorting ``extra_definitions`` in auto migrations.
+
+-------------------------------------------------------------------------------
+
+0.12.1
+------
+* Fixed typos.
+* Bumped requirements.
+
+-------------------------------------------------------------------------------
+
+0.12.0
+------
+* Added ``Date`` and ``Time`` columns (see the example below).
+* Improved support for column default values.
+* Auto migrations can now serialise more Python types.
+* Added ``Table.indexes`` method for listing table indexes.
+* Auto migrations can handle adding / removing indexes.
+* Improved ASGI template for FastAPI.
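+
+A minimal sketch of the new ``Date`` and ``Time`` columns (the ``Appointment``
+table is purely illustrative):
+
+.. code-block:: python
+
+    from piccolo.columns import Date, Time
+    from piccolo.table import Table
+
+
+    class Appointment(Table):
+        # Stores a calendar date (year / month / day):
+        booked_on = Date()
+        # Stores a time of day:
+        starts_at = Time()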
+
+-------------------------------------------------------------------------------
+
+0.11.8
+------
+ASGI template fix.
+
+-------------------------------------------------------------------------------
+
+0.11.7
+------
+* Improved ``UUID`` columns in SQLite - prepending 'uuid:' to the stored value
+  to make the type more explicit for the engine.
+* Removed SQLite as an option for ``piccolo asgi new`` until auto migrations
+  are supported.
+
+-------------------------------------------------------------------------------
+
+0.11.6
+------
+Added support for FastAPI to ``piccolo asgi new``.
+
+-------------------------------------------------------------------------------
+
+0.11.5
+------
+Fixed bug in ``BaseMigrationManager.get_migration_modules`` - wasn't
+excluding non-Python files well enough.
+
+-------------------------------------------------------------------------------
+
+0.11.4
+------
+* Stopped ``piccolo migrations new`` from creating a config.py file - was
+  legacy.
+* Added a README file to the ``piccolo_migrations`` folder in the ASGI
+  template.
+
+-------------------------------------------------------------------------------
+
+0.11.3
+------
+Fixed ``__pycache__`` bug when using ``piccolo asgi new``.
+
+-------------------------------------------------------------------------------
+
+0.11.2
+------
+* Showing a warning if trying auto migrations with SQLite.
+* Added a command for creating a new ASGI app - ``piccolo asgi new``.
+* Added a meta app for printing out the Piccolo version -
+  ``piccolo meta version``.
+* Added example queries to the playground.
+
+-------------------------------------------------------------------------------
+
+0.11.1
+------
+* Added ``table_finder``, for use in ``AppConfig``.
+* Added support for concatenating strings using an update query.
+* Added more tables to the playground, with more column types.
+* Improved consistency between SQLite and Postgres with ``UUID`` columns,
+  ``Integer`` columns, and ``exists`` queries.
+
+-------------------------------------------------------------------------------
+
+0.11.0
+------
+Added ``Numeric`` and ``Real`` column types.
+
+-------------------------------------------------------------------------------
+
+0.10.8
+------
+Fixing a bug where Postgres versions without a patch number couldn't be parsed.
+
+-------------------------------------------------------------------------------
+
+0.10.7
+------
+Improving release script.
+
+-------------------------------------------------------------------------------
+
+0.10.6
+------
+Sorting out packaging issue - old files were appearing in the release.
+
+-------------------------------------------------------------------------------
+
+0.10.5
+------
+Auto migrations can now run backwards.
+
+-------------------------------------------------------------------------------
+
+0.10.4
+------
+Fixing some typos with ``Table`` imports. Showing a traceback when
+``piccolo_conf`` can't be found by ``engine_finder``.
+
+-------------------------------------------------------------------------------
+
+0.10.3
+------
+Adding missing jinja templates to setup.py.
+
+-------------------------------------------------------------------------------
+
+0.10.2
+------
+Fixing a bug when using ``piccolo project new`` in a new project.
+
+-------------------------------------------------------------------------------
+
+0.10.1
+------
+Fixing bug with enum default values.
+
+-------------------------------------------------------------------------------
+
+0.10.0
+------
+Using targ for the CLI. Refactored some core code into apps.
+
+-------------------------------------------------------------------------------
+
+0.9.3
+-----
+Suppressing exceptions when trying to find the Postgres version, to avoid
+an ``ImportError`` when importing ``piccolo_conf.py``.
+
+-------------------------------------------------------------------------------
+
+0.9.2
+-----
+``.first()`` bug fix.
+
+-------------------------------------------------------------------------------
+
+0.9.1
+-----
+Auto migration fixes, and ``.first()`` method now returns None if no match is
+found.
+
+-------------------------------------------------------------------------------
+
+0.9.0
+-----
+Added support for auto migrations.
+
+-------------------------------------------------------------------------------
+
+0.8.3
+-----
+Can use operators in update queries, and fixing 'new' migration command.
+
+-------------------------------------------------------------------------------
+
+0.8.2
+-----
+Fixing release issue.
+
+-------------------------------------------------------------------------------
+
+0.8.1
+-----
+Improved transaction support - can now use a context manager. Added ``Secret``,
+``BigInt`` and ``SmallInt`` column types. Foreign keys can now reference the
+parent table.
+
+-------------------------------------------------------------------------------
+
+0.8.0
+-----
+Fixing bug when joining across several tables. Can pass values directly into
+the ``Table.update`` method. Added ``if_not_exists`` option when creating a
+table.
+
+-------------------------------------------------------------------------------
+
+0.7.7
+-----
+Column sequencing matches the definition order.
+
+-------------------------------------------------------------------------------
+
+0.7.6
+-----
+Supporting ``ON DELETE`` and ``ON UPDATE`` for foreign keys. Recording reverse
+foreign key relationships.
+
+-------------------------------------------------------------------------------
+
+0.7.5
+-----
+Made ``response_handler`` async. Made it easier to rename columns.
+
+-------------------------------------------------------------------------------
+
+0.7.4
+-----
+Bug fixes and dependency updates.
+
+-------------------------------------------------------------------------------
+
+0.7.3
+-----
+Adding missing ``__init__.py`` file.
+
+-------------------------------------------------------------------------------
+
+0.7.2
+-----
+Changed migration import paths.
+
+-------------------------------------------------------------------------------
+
+0.7.1
+-----
+Added ``remove_db_file`` method to ``SQLiteEngine`` - makes testing easier.
+
+-------------------------------------------------------------------------------
+
+0.7.0
+-----
+Renamed ``create`` to ``create_table``. Commands can now be registered via
+``piccolo_conf``.
+
+-------------------------------------------------------------------------------
+
+0.6.1
+-----
+Adding missing ``__init__.py`` files.
+
+-------------------------------------------------------------------------------
+
+0.6.0
+-----
+Moved ``BaseUser``. Migration refactor.
+
+-------------------------------------------------------------------------------
+
+0.5.2
+-----
+Moved drop table under ``Alter`` - to help prevent accidental drops.
+
+-------------------------------------------------------------------------------
+
+0.5.1
+-----
+Added ``batch`` support.
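+
+This lets large result sets be processed in chunks, rather than loading
+everything into memory at once. A rough sketch of the usage, as it looks in
+more recent versions of Piccolo (using the ``Manager`` table from earlier
+examples):
+
+.. code-block:: python
+
+    async with await Manager.select().batch(batch_size=100) as batch:
+        async for rows in batch:
+            # Each iteration yields at most 100 rows.
+            print(rows)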
+
+-------------------------------------------------------------------------------
+
+0.5.0
+-----
+Refactored the ``Table`` Metaclass - much simpler now. Scoped more of the
+attributes on ``Column`` to avoid name clashes. Added ``engine_finder`` to make
+database configuration easier.
+
+-------------------------------------------------------------------------------
+
+0.4.1
+-----
+SQLite is now returning datetime objects for timestamp fields.
+
+-------------------------------------------------------------------------------
+
+0.4.0
+-----
+Refactored to improve code completion, along with bug fixes.
+
+-------------------------------------------------------------------------------
+
+0.3.7
+-----
+Allowing ``Update`` queries in SQLite.
+
+-------------------------------------------------------------------------------
+
+0.3.6
+-----
+Falling back to ``LIKE`` instead of ``ILIKE`` for SQLite.
+
+-------------------------------------------------------------------------------
+
+0.3.5
+-----
+Renamed ``User`` to ``BaseUser``.
+
+-------------------------------------------------------------------------------
+
+0.3.4
+-----
+Added ``ilike``.
+
+-------------------------------------------------------------------------------
+
+0.3.3
+-----
+Added value types to columns.
+
+-------------------------------------------------------------------------------
+
+0.3.2
+-----
+Default values infer the engine type.
+
+-------------------------------------------------------------------------------
+
+0.3.1
+-----
+Updated click version.
+
+-------------------------------------------------------------------------------
+
+0.3
+---
+Tweaked API to support more auto completion. Join support in where clause.
+Basic SQLite support - mostly for playground.
+
+-------------------------------------------------------------------------------
+
+0.2
+---
+Using ``QueryString`` internally to represent queries, instead of raw strings,
+to harden against SQL injection.
+
+-------------------------------------------------------------------------------
+
+0.1.2
+-----
+Allowing joins across multiple tables.
+
+-------------------------------------------------------------------------------
+
+0.1.1
+-----
+Added playground.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 30756895b..e45d15986 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1 +1,97 @@
+# Contributing
+
+Thanks for your interest in the Piccolo project. 👍
+
+The aim of Piccolo is to build a fantastic ORM and query builder, with a world class admin GUI, which makes developers happy and productive. There are lots of ways to get involved. 🚀
+
+The community is friendly and responsive. You don't have to be a Python or database ninja to contribute. 🥷
+
+---
+
+## Pick an appropriate issue
+
+If you take a look at the issues list, there are some with the [`good first issue`](https://github.com/piccolo-orm/piccolo/labels/good%20first%20issue) tag. You can pick any issue you feel confident in tackling, or even create your own issue. However, the ones marked with `good first issue` are created with newcomers in mind.
+
+Once you've identified an issue that you're interested in working on, leave a comment on the issue letting others know. This will prevent multiple people accidentally working on the same issue.
+
+### What if there are no appropriate issues?
+
+There are always ways to improve a project:
+
+- Can code coverage be improved?
+- Is documentation lacking?
+- Are there any typos?
+- Can any code be optimised or cleaned up?
+
+If you can identify any areas of improvement, create an issue.
+
+---
+
+## Tips for Pull Requests (PRs)
+
+### Try to keep PRs simple
+
+The maintainers work on Piccolo alongside their day jobs, so very large PRs may take a long time to review. We try to leave feedback on PRs within a few days of them being opened, to keep things flowing.
+
+### A PR doesn't have to be perfect before opening it
+
+It's often better to open a pull request with a simple prototype, and get early feedback.
+
+### Avoid overly complex code
+
+Part of open source's appeal is that anyone can contribute to it. Keeping complexity in check is a constant battle in maintaining a clean codebase. Bear this in mind when opening a PR - the code should be high quality and well documented.
+
+---
+
+## What if my code doesn't get merged?
+
+Most contributions get merged. However, a PR can serve many purposes - it can spark discussion and ideas. There is no wasted effort in open source, as it always contributes to collective learning. If a PR doesn't get merged, it's usually because we decide on a different approach for solving the problem, or the complexity overhead of adding it outweighs the benefits.
+
+---
+
+## Contributing without writing code
+
+Even without writing code, there are lots of ways to get involved.
+
+### Documentation
+
+Is something in the documentation unclear or missing? These types of contributions are invaluable.
+
+### Design input
+
+Is there something about the design which you don't like? Constructive user feedback is incredibly useful.
+
+### Tutorials
+
+Can you write a blog article or tutorial about Piccolo?
+
+### Spreading the word
+
+Just starring the project, or tweeting about it, helps us a lot.
+
+---
+
+## Git usage
+
+If you're not confident with Git, then the [GitHub Desktop app](https://desktop.github.com/) is highly recommended.
+
+---
+
+## Project setup
+
+See the main [Piccolo docs](https://piccolo-orm.readthedocs.io/en/latest/piccolo/contributing/index.html).
+
+---
+
+## Sister projects
+
+The main Piccolo repo is just one piece of the ecosystem. There are other essential components which are open to contributions:
+
+- https://github.com/piccolo-orm/piccolo_admin
+- https://github.com/piccolo-orm/piccolo_api
+
+---
+
+## Becoming a project member
+
+There is no formal process for becoming a project member. Typically, people become members after making significant contributions to the project, or because they require specific permissions which make contributing to Piccolo easier.
diff --git a/README.md b/README.md index 2c11c1ad9..3cf8c6119 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ -# Piccolo +![Logo](https://raw.githubusercontent.com/piccolo-orm/piccolo/master/docs/logo_hero.png "Piccolo Logo") ![Tests](https://github.com/piccolo-orm/piccolo/actions/workflows/tests.yaml/badge.svg) ![Release](https://github.com/piccolo-orm/piccolo/actions/workflows/release.yaml/badge.svg) [![Documentation Status](https://readthedocs.org/projects/piccolo-orm/badge/?version=latest)](https://piccolo-orm.readthedocs.io/en/latest/?badge=latest) [![PyPI](https://img.shields.io/pypi/v/piccolo?color=%2334D058&label=pypi)](https://pypi.org/project/piccolo/) -[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/piccolo-orm/piccolo.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/piccolo-orm/piccolo/context:python) -[![Total alerts](https://img.shields.io/lgtm/alerts/g/piccolo-orm/piccolo.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/piccolo-orm/piccolo/alerts/) [![codecov](https://codecov.io/gh/piccolo-orm/piccolo/branch/master/graph/badge.svg?token=V19CWH7MXX)](https://codecov.io/gh/piccolo-orm/piccolo) -A fast, user friendly ORM and query builder which supports asyncio. [Read the docs](https://piccolo-orm.readthedocs.io/en/latest/). +Piccolo is a fast, user friendly ORM and query builder which supports asyncio. [Read the docs](https://piccolo-orm.readthedocs.io/en/latest/). ## Features @@ -19,6 +17,7 @@ Some of it’s stand out features are: - Tab completion support - works great with iPython and VSCode. - Batteries included - a User model, authentication, migrations, an [admin GUI](https://github.com/piccolo-orm/piccolo_admin), and more. - Modern Python - fully type annotated. +- Make your codebase modular and scalable with Piccolo apps (similar to Django apps). ## Syntax @@ -32,23 +31,23 @@ await Band.select( Band.name ).where( Band.popularity > 100 -).run() +) # Join: await Band.select( Band.name, Band.manager.name -).run() +) # Delete: await Band.delete().where( Band.popularity < 1000 -).run() +) # Update: await Band.update({Band.popularity: 10000}).where( Band.name == 'Pythonistas' -).run() +) ``` Or like a typical ORM: @@ -56,40 +55,88 @@ Or like a typical ORM: ```python # To create a new object: b = Band(name='C-Sharps', popularity=100) -await b.save().run() +await b.save() # To fetch an object from the database, and update it: -b = await Band.objects().where(Band.name == 'Pythonistas').first().run() +b = await Band.objects().get(Band.name == 'Pythonistas') b.popularity = 10000 -await b.save().run() +await b.save() # To delete: -await b.remove().run() +await b.remove() ``` ## Installation Installing with PostgreSQL driver: -``` +```bash pip install 'piccolo[postgres]' ``` Installing with SQLite driver: -``` +```bash pip install 'piccolo[sqlite]' ``` +Installing with all optional dependencies (easiest): + +```bash +pip install 'piccolo[all]' +``` + ## Building a web app? Let Piccolo scaffold you an ASGI web app, using Piccolo as the ORM: -``` +```bash piccolo asgi new ``` -[Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), and [BlackSheep](https://www.neoteroi.dev/blacksheep/) are currently supported. 
+[Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), [BlackSheep](https://www.neoteroi.dev/blacksheep/), [Litestar](https://litestar.dev/), [Ravyn](https://www.ravyn.dev/), [Lilya](https://lilya.dev/), [Quart](https://quart.palletsprojects.com/en/latest/), [Falcon](https://falconframework.org/) and [Sanic](https://sanic.dev/en/) are currently supported.
+
+## Piccolo ecosystem
+
+### Piccolo Admin
+
+Piccolo Admin is a powerful admin interface / content management system for Python, built on top of Piccolo.
+
+It was created at a design agency to serve the needs of customers who demand a high quality, beautiful admin interface for their websites. It's a modern alternative to tools like Wordpress and Django Admin.
+
+It's built using the latest technologies, with Vue.js on the front end, and a powerful REST backend.
+
+Some of its standout features:
+
+* Powerful data filtering
+* Builtin security
+* Multi-factor Authentication
+* Media support, both locally and in S3 compatible services
+* Dark mode support
+* CSV exports
+* Easily create custom forms
+* Works on mobile and desktop
+* Use standalone, or integrate with several supported ASGI frameworks
+* Multilingual out of the box
+* Bulk actions, like updating and deleting data
+* Flexible UI - only show the columns you want your users to see
+
+You can read the docs [here](https://piccolo-admin.readthedocs.io/en/latest/).
+
+### Piccolo API
+
+Utilities for easily exposing [Piccolo](https://piccolo-orm.readthedocs.io/en/latest/) tables as REST endpoints in ASGI apps, such as [Starlette](https://www.starlette.io) and [FastAPI](https://fastapi.tiangolo.com/).
+
+Includes a bunch of useful ASGI middleware:
+
+- Session Auth
+- Token Auth
+- Rate Limiting
+- CSRF
+- Content Security Policy (CSP)
+- And more
+
+You can read the docs [here](https://piccolo-api.readthedocs.io/en/latest/).
 
 ## Are you a Django user?
 
@@ -97,4 +144,6 @@ We have a handy page which shows the equivalent of [common Django queries in Pic
 
 ## Documentation
 
-See [Read the docs](https://piccolo-orm.readthedocs.io/en/latest/piccolo/getting_started/index.html).
+Our documentation is on [Read the docs](https://piccolo-orm.readthedocs.io/en/latest/piccolo/getting_started/index.html).
+
+We also have some great [tutorial videos on YouTube](https://www.youtube.com/channel/UCE7x5nm1Iy9KDfXPNrNQ5lA).
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..ed95cb8a7
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,13 @@
+# Security Policy
+
+## Supported Versions
+
+v1 is actively maintained, and any security vulnerabilities will be patched.
+
+v0.X will have any major security vulnerabilities patched.
+
+## Reporting a Vulnerability
+
+We recommend opening a security advisory on GitHub, as per the [documentation](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability).
+
+Alternatively, reach out to the maintainers via email (see [setup.py](https://github.com/piccolo-orm/piccolo/blob/bbd2e4ad6378b2080d58fb7c7ed392f0425f0f21/setup.py#L60) for contact details).
diff --git a/docs/doc-requirements.txt b/docs/doc-requirements.txt deleted file mode 100644 index 3c254d122..000000000 --- a/docs/doc-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -Sphinx==3.2.1 -sphinx-rtd-theme==0.5.0 -livereload==2.6.3 diff --git a/docs/logo_hero.png b/docs/logo_hero.png new file mode 100644 index 000000000..fde37ca2f Binary files /dev/null and b/docs/logo_hero.png differ diff --git a/docs/serve_docs.py b/docs/serve_docs.py deleted file mode 100755 index 50ed6a05f..000000000 --- a/docs/serve_docs.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -from livereload import Server, shell - - -server = Server() -server.watch( - 'src/', - shell('make html') -) -server.watch( - '../piccolo', - shell('make html') -) -server.serve(root='build/html') diff --git a/docs/src/conf.py b/docs/src/conf.py index dee72d93c..fd5f0d2c0 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -15,7 +15,6 @@ import datetime import os import sys -import typing as t sys.path.insert(0, os.path.abspath("../..")) @@ -34,23 +33,10 @@ # -- General configuration --------------------------------------------------- -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.todo", "sphinx.ext.coverage", - "sphinx.ext.githubpages", ] -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # @@ -60,70 +46,40 @@ # The master toctree document. master_doc = "index" -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns: t.List[str] = [] +# -- Intersphinx ------------------------------------------------------------- -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = None +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "piccolo_api": ("https://piccolo-api.readthedocs.io/en/latest/", None), +} +extensions += ["sphinx.ext.intersphinx"] # -- Autodoc ----------------------------------------------------------------- +extensions += ["sphinx.ext.autodoc"] autodoc_typehints = "signature" +autodoc_typehints_format = "short" +autoclass_content = "both" # -- Options for HTML output ------------------------------------------------- -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "sphinx_rtd_theme" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ["_static"] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# The default sidebars (for documents that don't match any pattern) are -# defined by theme itself. Builtin themes are using these templates by -# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', -# 'searchbox.html']``. -# -# html_sidebars = {} - +html_theme = "piccolo_theme" +html_short_title = "Piccolo" +html_show_sphinx = False +globaltoc_maxdepth = 3 +html_theme_options = { + "source_url": "https://github.com/piccolo-orm/piccolo/", + "banner_text": 'Piccolo Admin now supports Multi-factor Authentication!', # noqa : E501 + "banner_hiding": "permanent", +} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = "Piccolodoc" - # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, "piccolo", "Piccolo Documentation", [author], 1)] - - -# -- Extension configuration ------------------------------------------------- - -# -- Options for todo extension ---------------------------------------------- - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True diff --git a/docs/src/index.rst b/docs/src/index.rst index b38222dd6..82b6bd690 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -1,5 +1,13 @@ -Welcome to Piccolo's documentation! -=================================== +.. note:: These are the docs for **Piccolo v1**. :ref:`Read more here `. + + For v0.x docs `go here `_. + + +Piccolo +======= + +Piccolo is a modern, async query builder and ORM for Python, with lots of +batteries included. .. toctree:: :maxdepth: 1 @@ -8,18 +16,24 @@ Welcome to Piccolo's documentation! piccolo/getting_started/index piccolo/query_types/index piccolo/query_clauses/index + piccolo/functions/index piccolo/schema/index piccolo/projects_and_apps/index piccolo/engines/index piccolo/migrations/index piccolo/authentication/index piccolo/asgi/index + piccolo/serialization/index + piccolo/testing/index piccolo/features/index piccolo/playground/index - piccolo/deployment/index piccolo/ecosystem/index + piccolo/tutorials/index piccolo/contributing/index piccolo/changes/index + piccolo/help/index + piccolo/api_reference/index + piccolo/v1/index ------------------------------------------------------------------------------- @@ -45,3 +59,14 @@ Give me an ASGI web app! .. code-block:: bash piccolo asgi new + +FastAPI, Starlette, BlackSheep, Litestar, Ravyn, Lilya, Quart, Falcon and Sanic +are currently supported, with more coming soon. + +---------------------------------------------------------------------------------- + +Videos +------ + +Piccolo has some `tutorial videos on YouTube `_, +which are a great companion to the docs. diff --git a/docs/src/logo.png b/docs/src/logo.png new file mode 100644 index 000000000..2afc9b781 Binary files /dev/null and b/docs/src/logo.png differ diff --git a/docs/src/piccolo/api_reference/index.rst b/docs/src/piccolo/api_reference/index.rst new file mode 100644 index 000000000..fbe6feaba --- /dev/null +++ b/docs/src/piccolo/api_reference/index.rst @@ -0,0 +1,146 @@ +API reference +============= + +Table +----- + +.. currentmodule:: piccolo.table + +.. 
autoclass:: Table + :members: + +------------------------------------------------------------------------------- + +SchemaManager +------------- + +.. currentmodule:: piccolo.schema + +.. autoclass:: SchemaManager + :members: + +------------------------------------------------------------------------------- + +Column +------ + +.. currentmodule:: piccolo.columns.base + +.. autoclass:: Column + :members: + + +.. autoclass:: ColumnKwargs + :members: + :undoc-members: + +------------------------------------------------------------------------------- + +Aggregate functions +------------------- + +Count +~~~~~ + +.. currentmodule:: piccolo.query.methods.select + +.. autoclass:: Count + +------------------------------------------------------------------------------- + +Refresh +------- + +.. currentmodule:: piccolo.query.methods.refresh + +.. autoclass:: Refresh + :members: + +------------------------------------------------------------------------------- + +LazyTableReference +------------------ + +.. currentmodule:: piccolo.columns + +.. autoclass:: LazyTableReference + :members: + +------------------------------------------------------------------------------- + +Enums +----- + +Foreign Keys +~~~~~~~~~~~~ + +.. currentmodule:: piccolo.columns + +.. autoclass:: OnDelete + :members: + :undoc-members: + +.. autoclass:: OnUpdate + :members: + :undoc-members: + +.. currentmodule:: piccolo.columns.indexes + +Indexes +~~~~~~~ + +.. autoclass:: IndexMethod + :members: + :undoc-members: + +------------------------------------------------------------------------------- + +Column defaults +--------------- + +.. currentmodule:: piccolo.columns.defaults + +Date +~~~~ + +.. autoclass:: DateOffset + :members: + + +UUID +~~~~ + +.. autoclass:: UUID4 + :members: + +------------------------------------------------------------------------------- + +Testing +------- + +.. currentmodule:: piccolo.testing.model_builder + +ModelBuilder +~~~~~~~~~~~~ + +.. autoclass:: ModelBuilder + :members: + +.. currentmodule:: piccolo.table + +create_db_tables / drop_db_tables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: create_db_tables +.. autofunction:: create_db_tables_sync +.. autofunction:: drop_db_tables +.. autofunction:: drop_db_tables_sync + +------------------------------------------------------------------------------- + +QueryString +----------- + +.. currentmodule:: piccolo.querystring + +.. autoclass:: QueryString diff --git a/docs/src/piccolo/asgi/index.rst b/docs/src/piccolo/asgi/index.rst index da560d62d..a05c75874 100644 --- a/docs/src/piccolo/asgi/index.rst +++ b/docs/src/piccolo/asgi/index.rst @@ -15,27 +15,34 @@ By using the ``piccolo asgi new`` command, Piccolo will scaffold an ASGI web app for you, which includes everything you need to get started. The command will ask for your preferences on which libraries to use. +------------------------------------------------------------------------------- + Routing frameworks ****************** -Currently, `Starlette `_, `FastAPI `_, -and `BlackSheep `_ are supported. - -Other great ASGI routing frameworks exist, and may be supported in the future -(`Quart `_ , -`Sanic `_ , -`Django `_ etc). +`Starlette `_, `FastAPI `_, +`BlackSheep `_, +`Litestar `_, `Ravyn `_, +`Lilya `_, +`Quart `_, +`Falcon `_ +and `Sanic `_ are supported. Which to use? ============= -All are great choices. FastAPI is built on top of Starlette, so they're -very similar. FastAPI is useful if you want to document a REST API. +All are great choices. 
FastAPI is built on top of Starlette and Ravyn is built on top of Lilya, so they're +very similar. FastAPI, BlackSheep, Litestar and Ravyn are great if you want to document a REST +API, as they have built-in OpenAPI support. + +------------------------------------------------------------------------------- Web servers ************ -`Hypercorn `_ and -`Uvicorn `_ are available as ASGI servers. +`Uvicorn `_, +`Hypercorn `_ +and `Granian `_ +are available as ASGI servers. `Daphne `_ can't be used programmatically so was omitted at this time. diff --git a/docs/src/piccolo/authentication/baseuser.rst b/docs/src/piccolo/authentication/baseuser.rst index a3790382b..c3dfcbbff 100644 --- a/docs/src/piccolo/authentication/baseuser.rst +++ b/docs/src/piccolo/authentication/baseuser.rst @@ -23,17 +23,34 @@ Commands The app comes with some useful commands. -user create -~~~~~~~~~~~ +create +~~~~~~ -Create a new user. +Creates a new user. It presents an interactive prompt, asking for the username, +password etc. .. code-block:: bash piccolo user create -user change_password -~~~~~~~~~~~~~~~~~~~~ +If you'd prefer to create a user without the interactive prompt (perhaps in a +script), you can pass all of the arguments in as follows: + +.. code-block:: bash + + piccolo user create --username=bob --password=bob123 --email=foo@bar.com --is_admin=t --is_superuser=t --is_active=t + +.. warning:: + If you choose this approach then be careful, as the password will be in the + shell's history. + +list +~~~~ + +List existing users. + +change_password +~~~~~~~~~~~~~~~ Change a user's password. @@ -41,8 +58,8 @@ Change a user's password. piccolo user change_password -user change_permissions -~~~~~~~~~~~~~~~~~~~~~~~ +change_permissions +~~~~~~~~~~~~~~~~~~ Change a user's permissions. The options are ``--admin``, ``--superuser`` and ``--active``, which change the corresponding attributes on ``BaseUser``. @@ -53,20 +70,39 @@ For example: piccolo user change_permissions some_user --active=true -The Piccolo Admin (see :ref:`Ecosystem`) uses these attributes to control who +The :ref:`Piccolo Admin` uses these attributes to control who can login and what they can do. - * **active** and **admin** - must be true for a user to be able to login. - * **superuser** - must be true for a user to be able to change other user's - passwords. +* **active** and **admin** - must be true for a user to be able to log in. +* **superuser** - must be true for a user to be able to change other users' + passwords. ------------------------------------------------------------------------------- Within your code ---------------- -login -~~~~~ +create_user / create_user_sync +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To create a new user: + +.. code-block:: python + + # From within a coroutine: + await BaseUser.create_user(username="bob", password="abc123", active=True) + + # When not in an event loop: + BaseUser.create_user_sync(username="bob", password="abc123", active=True) + +It saves the user in the database, and returns the created ``BaseUser`` +instance. + +.. note:: It is preferable to use this rather than instantiating and saving + ``BaseUser`` directly, as we add additional validation. + +login / login_sync +~~~~~~~~~~~~~~~~~~ To check a user's credentials, do the following: @@ -93,10 +129,54 @@ To change a user's password: ..
code-block:: python # From within a coroutine: - await BaseUser.update_password(username="bob", password="abc123") + await BaseUser.update_password(user="bob", password="abc123") # When not in an event loop: - BaseUser.update_password_sync(username="bob", password="abc123") + BaseUser.update_password_sync(user="bob", password="abc123") .. warning:: Don't use bulk updates for passwords - use ``update_password`` / ``update_password_sync``, and they'll correctly hash the password. + +------------------------------------------------------------------------------- + +Limits +------ + +The maximum password length allowed is 128 characters. This should be +sufficiently long for most use cases. + +The minimum password length allowed is 6 characters. + +------------------------------------------------------------------------------- + +Extending ``BaseUser`` +---------------------- + +If you want to extend ``BaseUser`` with additional fields, we recommend creating +a ``Profile`` table with a ``ForeignKey`` to ``BaseUser``, which can include +any custom fields. + +.. code-block:: python + + from piccolo.apps.user.tables import BaseUser + from piccolo.columns import ForeignKey, Text, Varchar + from piccolo.table import Table + + class Profile(Table): + custom_user = ForeignKey(BaseUser) + phone_number = Varchar() + bio = Text() + +Alternatively, you can copy the entire `user app `_ into your +project, and customise it to fit your needs. + +---------------------------------------------------------------------------------- + +Source +------ + +.. currentmodule:: piccolo.apps.user.tables + +.. autoclass:: BaseUser + :members: create_user, create_user_sync, login, login_sync, update_password, update_password_sync + :class-doc-from: class diff --git a/docs/src/piccolo/authentication/index.rst b/docs/src/piccolo/authentication/index.rst index 731caaf00..48609827a 100644 --- a/docs/src/piccolo/authentication/index.rst +++ b/docs/src/piccolo/authentication/index.rst @@ -3,7 +3,7 @@ Authentication ============== -Piccolo ships with some authentication support out of the box. +Piccolo ships with authentication support out of the box. ------------------------------------------------------------------------------- @@ -22,3 +22,14 @@ Tables :maxdepth: 1 ./baseuser + +------------------------------------------------------------------------------- + +Web app integration +------------------- + +Our sister project, `Piccolo API `_, +contains powerful endpoints and middleware for integrating +`session auth `_ +and `token auth `_ +into your ASGI web application, using ``BaseUser``. diff --git a/docs/src/piccolo/changes/index.rst b/docs/src/piccolo/changes/index.rst index d2f9f30f4..55704f02a 100644 --- a/docs/src/piccolo/changes/index.rst +++ b/docs/src/piccolo/changes/index.rst @@ -1 +1 @@ -.. include:: ../../../../CHANGES +.. include:: ../../../../CHANGES.rst diff --git a/docs/src/piccolo/contributing/index.rst b/docs/src/piccolo/contributing/index.rst index ddf897cfa..7a53e1dd0 100644 --- a/docs/src/piccolo/contributing/index.rst +++ b/docs/src/piccolo/contributing/index.rst @@ -6,37 +6,64 @@ Contributing If you want to dig deeper into the Piccolo internals, follow these instructions. +------------------------------------------------------------------------------- + +Running Cockroach +----------------- + +To get a local Cockroach instance running, you can use: + +.. code-block:: console + + cockroach start-single-node --insecure --store=type=mem,size=2GiB + +Make sure the test database exists: + +.. 
code-block:: console + + cockroach sql + >>> create database piccolo + >>> use piccolo + +------------------------------------------------------------------------------- + Get the tests running --------------------- - * Create a new virtualenv - * Clone the `Git repo `_ - * ``cd piccolo`` - * Install default dependencies: ``pip install -r requirements/requirements.txt`` - * Install development dependencies: ``pip install -r requirements/dev-requirements.txt`` - * Install test dependencies: ``pip install -r requirements/test-requirements.txt`` - * Setup Postgres - * Run the automated code linting/formatting tools: ``./scripts/lint.sh`` - * Run the test suite with Postgres: ``./scripts/test-postgres.sh`` - * Run the test suite with Sqlite: ``./scripts/test-sqlite.sh`` +* Create a new virtualenv +* Clone the `Git repo `_ +* ``cd piccolo`` +* Install default dependencies: ``pip install -r requirements/requirements.txt`` +* Install development dependencies: ``pip install -r requirements/dev-requirements.txt`` +* Install test dependencies: ``pip install -r requirements/test-requirements.txt`` +* Install database drivers: ``pip install -r requirements/extras/postgres.txt -r requirements/extras/sqlite.txt`` +* Setup Postgres, and make sure a database called ``piccolo`` exists (see ``tests/postgres_conf.py``). +* Run the automated code linting/formatting tools: ``./scripts/lint.sh`` +* Run the test suite with Postgres: ``./scripts/test-postgres.sh`` +* Run the test suite with Cockroach: ``./scripts/test-cockroach.sh`` +* Run the test suite with SQLite: ``./scripts/test-sqlite.sh`` + +------------------------------------------------------------------------------- Contributing to the docs ------------------------ The docs are written using Sphinx. To get them running locally: - * ``cd docs`` - * Install the requirements: ``pip install -r doc-requirements.txt`` - * Do an initial build of the docs: ``make html`` - * Serve the docs: ``python serve_docs.py`` - * The docs will auto rebuild as you make changes. +* Install the requirements: ``pip install -r requirements/doc-requirements.txt`` +* ``cd docs`` +* Do an initial build of the docs: ``make html`` +* Serve the docs: ``./scripts/run-docs.sh`` +* The docs will auto rebuild as you make changes. + +------------------------------------------------------------------------------- Code style ---------- Piccolo uses `Black `_ for formatting, preferably with a max line length of 79, to keep it consistent -with `PEP8 `_ . +with `PEP8 `_. You can configure `VSCode `_ by modifying ``settings.json`` as follows: @@ -55,3 +82,15 @@ You can configure `VSCode `_ by modifying } Type hints are used throughout the project. + +------------------------------------------------------------------------------- + +Profiling +--------- + +This isn't required to contribute to Piccolo, but is useful when investigating +performance problems. + + * Install the dependencies: ``pip install -r requirements/profile-requirements.txt`` + * Make sure a Postgres database called ``piccolo_profile`` exists. + * Run ``./scripts/profile.sh`` to get performance data. diff --git a/docs/src/piccolo/deployment/index.rst b/docs/src/piccolo/deployment/index.rst deleted file mode 100644 index 1fc595b9b..000000000 --- a/docs/src/piccolo/deployment/index.rst +++ /dev/null @@ -1,16 +0,0 @@ -Deployment -========== - -Docker ------- - -Piccolo has several dependencies which are compiled (e.g.
asyncpg, orjson), -which is great for performance, but you may run into difficulties when using -Alpine Linux as your base Docker image. - -Alpine uses a different compiler toolchain to most Linux distros. It's -highly recommended to use Debian as your base Docker image. Many Python packages -have prebuilt versions for Debian, meaning you don't have to compile them at -all during install. The result is a much faster build process, and potentially -even a smaller overall Docker image size (the size of Alpine quickly balloons -after you've added all of the compilation dependencies). diff --git a/docs/src/piccolo/ecosystem/index.rst b/docs/src/piccolo/ecosystem/index.rst index d1bb9a85a..8bcf5aff4 100644 --- a/docs/src/piccolo/ecosystem/index.rst +++ b/docs/src/piccolo/ecosystem/index.rst @@ -7,17 +7,34 @@ Piccolo API ----------- Provides some handy utilities for creating an API around your Piccolo tables. -Examples include easy CRUD endpoints for ASGI apps, authentication and -rate limiting. `Read the docs `_. +Examples include: + +* Easily creating CRUD endpoints for ASGI apps, based on Piccolo tables. +* Automatically creating Pydantic models from your Piccolo tables. +* Great FastAPI integration. +* Authentication and rate limiting. + +`See the docs `_ for +more information. + +------------------------------------------------------------------------------- + +.. _PiccoloAdmin: Piccolo Admin ------------- Lets you create a powerful web GUI for your tables in two minutes. View the -project on `Github `_. +project on `Github `_, and the +`docs `_ for more information. .. image:: https://raw.githubusercontent.com/piccolo-orm/piccolo_admin/master/docs/images/screenshot.png +It's a modern UI built with Vue JS, which supports powerful data filtering, and +CSV exports. It's the crown jewel in the Piccolo ecosystem! + +------------------------------------------------------------------------------- + Piccolo Examples ---------------- diff --git a/docs/src/piccolo/engines/cockroach_engine.rst b/docs/src/piccolo/engines/cockroach_engine.rst new file mode 100644 index 000000000..a3c3cd85f --- /dev/null +++ b/docs/src/piccolo/engines/cockroach_engine.rst @@ -0,0 +1,42 @@ +CockroachEngine +=============== + +Configuration +------------- + +.. code-block:: python + + # piccolo_conf.py + from piccolo.engine.cockroach import CockroachEngine + + + DB = CockroachEngine(config={ + 'host': 'localhost', + 'database': 'piccolo', + 'user': 'root', + 'password': '', + 'port': '26257', + }) + +config +~~~~~~ + +The config dictionary is passed directly to the underlying database adapter, +asyncpg. See the `asyncpg docs `_ +to learn more. + +------------------------------------------------------------------------------- + +Connection Pool +--------------- + +See :ref:`ConnectionPool`. + +------------------------------------------------------------------------------- + +Source +------ + +.. currentmodule:: piccolo.engine.cockroach + +.. autoclass:: CockroachEngine diff --git a/docs/src/piccolo/engines/connection_pool.rst b/docs/src/piccolo/engines/connection_pool.rst new file mode 100644 index 000000000..f5856a917 --- /dev/null +++ b/docs/src/piccolo/engines/connection_pool.rst @@ -0,0 +1,80 @@ +.. _ConnectionPool: + +Connection Pool +=============== + +.. hint:: Connection pools can be used with Postgres and CockroachDB. + +Setup +~~~~~ + +To use a connection pool, you need to first initialise it. The best place to do +this is in the startup event handler of whichever web framework you are using. 
+We also want to close the connection pool in the shutdown event handler. + +The recommended way for Starlette and FastAPI apps is to use the ``lifespan`` +parameter: + +.. code-block:: python + + from contextlib import asynccontextmanager + from piccolo.engine import engine_finder + from starlette.applications import Starlette + + + @asynccontextmanager + async def lifespan(app: Starlette): + engine = engine_finder() + assert engine + await engine.start_connection_pool() + yield + await engine.close_connection_pool() + + + app = Starlette(lifespan=lifespan) + +In older versions of Starlette and FastAPI, you may need event handlers +instead: + +.. code-block:: python + + from piccolo.engine import engine_finder + from starlette.applications import Starlette + + + app = Starlette() + + + @app.on_event('startup') + async def open_database_connection_pool(): + engine = engine_finder() + await engine.start_connection_pool() + + + @app.on_event('shutdown') + async def close_database_connection_pool(): + engine = engine_finder() + await engine.close_connection_pool() + +.. hint:: Using a connection pool helps with performance, since connections + are reused instead of being created for each query. + +Once a connection pool has been started, the engine will use it for making +queries. + +.. hint:: If you're running several instances of an app on the same server, + you may prefer an external connection pooler - like pgbouncer. + +------------------------------------------------------------------------------- + +Configuration +~~~~~~~~~~~~~ + +The connection pool uses the same configuration as your engine. You can also +pass in additional parameters, which are passed to the underlying database +adapter. Here's an example: + +.. code-block:: python + + # To increase the number of connections available: + await engine.start_connection_pool(max_size=20) \ No newline at end of file diff --git a/docs/src/piccolo/engines/index.rst b/docs/src/piccolo/engines/index.rst index 031cf3eb5..db655c76b 100644 --- a/docs/src/piccolo/engines/index.rst +++ b/docs/src/piccolo/engines/index.rst @@ -4,7 +4,7 @@ Engines ======= Engines are what execute the SQL queries. Each supported backend has its own -engine (see  :ref:`EngineTypes`). +:ref:`engine `. It's important that each ``Table`` class knows which engine to use. There are two ways of doing this - setting it explicitly via the ``db`` argument, or @@ -58,9 +58,11 @@ Here's an example ``piccolo_conf.py`` file: DB = SQLiteEngine(path='my_db.sqlite') -.. hint:: A good place for your piccolo_conf file is at the root of your +.. hint:: A good place for your ``piccolo_conf.py`` file is at the root of your project, where the Python interpreter will be launched. +.. _PICCOLO_CONF: + PICCOLO_CONF environment variable ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -73,7 +75,7 @@ In your terminal: export PICCOLO_CONF=piccolo_conf_test -Or at the entypoint for your app, before any other imports: +Or at the entrypoint of your app, before any other imports: .. code-block:: python @@ -82,10 +84,11 @@ Or at the entypoint for your app, before any other imports: This is helpful during tests - you can specify a different configuration file -which contains the connection details for a test database. Similarly, -it's useful if you're deploying your code to different environments (e.g. -staging and production). Have two configuration files, and set the environment -variable accordingly. +which contains the connection details for a test database. + +.. 
hint:: Piccolo has a builtin command which will do this for you - + automatically setting ``PICCOLO_CONF`` for the duration of your tests. See + :ref:`TesterApp`. .. code-block:: python @@ -95,6 +98,11 @@ variable accordingly. DB = SQLiteEngine(path='my_test_db.sqlite') + +It's also useful if you're deploying your code to different environments (e.g. +staging and production). Have two configuration files, and set the environment +variable accordingly. + If the ``piccolo_conf.py`` file is located in a sub-module (rather than the root of your project) you can specify the path like this: @@ -118,3 +126,5 @@ Engine types ./sqlite_engine ./postgres_engine + ./cockroach_engine + ./connection_pool diff --git a/docs/src/piccolo/engines/postgres_engine.rst b/docs/src/piccolo/engines/postgres_engine.rst index 16a9f6bac..6292c2d71 100644 --- a/docs/src/piccolo/engines/postgres_engine.rst +++ b/docs/src/piccolo/engines/postgres_engine.rst @@ -26,55 +26,10 @@ to learn more. ------------------------------------------------------------------------------- -Connection pool +Connection Pool --------------- -To use a connection pool, you need to first initialise it. The best place to do -this is in the startup event handler of whichever web framework you are using. - -Here's an example using Starlette. Notice that we also close the connection -pool in the shutdown event handler. - -.. code-block:: python - - from piccolo.engine import engine_finder - from starlette.applications import Starlette - - - app = Starlette() - - - @app.on_event('startup') - async def open_database_connection_pool(): - engine = engine_finder() - await engine.start_connection_pool() - - - @app.on_event('shutdown') - async def close_database_connection_pool(): - engine = engine_finder() - await engine.close_connection_pool() - -.. hint:: Using a connection pool helps with performance, since connections - are reused instead of being created for each query. - -Once a connection pool has been started, the engine will use it for making -queries. - -.. hint:: If you're running several instances of an app on the same server, - you may prefer an external connection pooler - like pgbouncer. - -Configuration -~~~~~~~~~~~~~ - -The connection pool uses the same configuration as your engine. You can also -pass in additional parameters, which are passed to the underlying database -adapter. Here's an example: - -.. code-block:: python - - # To increase the number of connections available: - await engine.start_connection_pool(max_size=20) +See :ref:`ConnectionPool`. ------------------------------------------------------------------------------- diff --git a/docs/src/piccolo/engines/sqlite_engine.rst b/docs/src/piccolo/engines/sqlite_engine.rst index 7397edca5..bcd5869e9 100644 --- a/docs/src/piccolo/engines/sqlite_engine.rst +++ b/docs/src/piccolo/engines/sqlite_engine.rst @@ -23,3 +23,11 @@ Source .. currentmodule:: piccolo.engine.sqlite .. autoclass:: SQLiteEngine + +------------------------------------------------------------------------------- + +Production tips +--------------- + +If you're planning on using SQLite in production with Piccolo, with lots of +concurrent queries, then here are some :ref:`useful tips `. diff --git a/docs/src/piccolo/features/index.rst b/docs/src/piccolo/features/index.rst index d1d4b015e..0be80ebef 100644 --- a/docs/src/piccolo/features/index.rst +++ b/docs/src/piccolo/features/index.rst @@ -4,7 +4,6 @@ Features .. 
toctree:: :maxdepth: 1 - ./tab_completion - ./supported_databases + ./types_and_tab_completion ./security ./syntax diff --git a/docs/src/piccolo/features/security.rst b/docs/src/piccolo/features/security.rst index dddd24f2f..bef04cb6d 100644 --- a/docs/src/piccolo/features/security.rst +++ b/docs/src/piccolo/features/security.rst @@ -6,7 +6,7 @@ Security SQL Injection protection ------------------------ -If you look under the hood, Piccolo uses a custom class called `QueryString` +If you look under the hood, Piccolo uses a custom class called ``QueryString`` for composing queries. It keeps query parameters separate from the query string, so we can pass parameterised queries to the engine. This helps prevent SQL Injection attacks. diff --git a/docs/src/piccolo/features/supported_databases.rst b/docs/src/piccolo/features/supported_databases.rst deleted file mode 100644 index fa022d3f2..000000000 --- a/docs/src/piccolo/features/supported_databases.rst +++ /dev/null @@ -1,12 +0,0 @@ -Supported Databases -=================== - -Postgres --------- -Postgres is the primary focus for Piccolo, and is what we expect most people -will be using in production. - -SQLite ------- -SQLite support is not as complete as Postgres, but it is available - mostly -because it's easy to setup. diff --git a/docs/src/piccolo/features/syntax.rst b/docs/src/piccolo/features/syntax.rst index ee0fb3ba0..82deeae95 100644 --- a/docs/src/piccolo/features/syntax.rst +++ b/docs/src/piccolo/features/syntax.rst @@ -9,8 +9,10 @@ closely as possible. For example: - * In other ORMs, you define models - in Piccolo you define tables. - * Rather than using a filter method, you use a `where` method like in SQL. +* In other ORMs, you define models - in Piccolo you define tables. +* Rather than using a filter method, you use a `where` method like in SQL. + +------------------------------------------------------------------------------- Get the SQL at any time ----------------------- diff --git a/docs/src/piccolo/features/tab_completion.rst b/docs/src/piccolo/features/tab_completion.rst deleted file mode 100644 index 674f05f1a..000000000 --- a/docs/src/piccolo/features/tab_completion.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. _tab_completion: - -Tab Completion -============== - -Piccolo does everything possible to support tab completion. It has been tested -with iPython and VSCode. - -To find out more about how it was done, read `this article `_. diff --git a/docs/src/piccolo/features/types_and_tab_completion.rst b/docs/src/piccolo/features/types_and_tab_completion.rst new file mode 100644 index 000000000..ed7d1158f --- /dev/null +++ b/docs/src/piccolo/features/types_and_tab_completion.rst @@ -0,0 +1,49 @@ +.. _tab_completion: + +Types and Tab Completion +======================== + +Type annotations +---------------- + +The Piccolo codebase uses type annotations extensively. This means it has great +tab completion support in tools like iPython and VSCode. + +It also means it works well with type checkers like Mypy. + +To learn more about how Piccolo achieves this, read this `article about type annotations `_, +and this `article about descriptors `_. + +------------------------------------------------------------------------------- + +Troubleshooting +--------------- + +Here are some issues you may encounter when using Mypy, or another type +checker. 
+ +``id`` column doesn't exist +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you don't explicitly declare a column on your table with ``primary_key=True``, +Piccolo creates a ``Serial`` column for you called ``id``. + +In the following situation, the type checker might complain that ``id`` +doesn't exist: + +.. code-block:: python + + await Band.select(Band.id) + +You can fix this as follows: + +.. code-block:: python + + # tables.py + from piccolo.table import Table + from piccolo.columns.column_types import Serial, Varchar + + + class Band(Table): + id: Serial # Add an annotation + name = Varchar() diff --git a/docs/src/piccolo/functions/aggregate.rst b/docs/src/piccolo/functions/aggregate.rst new file mode 100644 index 000000000..3b1a95a4e --- /dev/null +++ b/docs/src/piccolo/functions/aggregate.rst @@ -0,0 +1,34 @@ +Aggregate functions +=================== + +.. currentmodule:: piccolo.query.functions.aggregate + +Avg +--- + +.. autoclass:: Avg + :class-doc-from: class + +Count +----- + +.. autoclass:: Count + :class-doc-from: class + +Min +--- + +.. autoclass:: Min + :class-doc-from: class + +Max +--- + +.. autoclass:: Max + :class-doc-from: class + +Sum +--- + +.. autoclass:: Sum + :class-doc-from: class diff --git a/docs/src/piccolo/functions/array.rst b/docs/src/piccolo/functions/array.rst new file mode 100644 index 000000000..aa457210a --- /dev/null +++ b/docs/src/piccolo/functions/array.rst @@ -0,0 +1,29 @@ +Array functions +=============== + +.. currentmodule:: piccolo.query.functions.array + +ArrayCat +-------- + +.. autoclass:: ArrayCat + +ArrayAppend +----------- + +.. autoclass:: ArrayAppend + +ArrayPrepend +------------ + +.. autoclass:: ArrayPrepend + +ArrayRemove +----------- + +.. autoclass:: ArrayRemove + +ArrayReplace +------------ + +.. autoclass:: ArrayReplace diff --git a/docs/src/piccolo/functions/basic_usage.rst b/docs/src/piccolo/functions/basic_usage.rst new file mode 100644 index 000000000..de45883b0 --- /dev/null +++ b/docs/src/piccolo/functions/basic_usage.rst @@ -0,0 +1,53 @@ +Basic Usage +=========== + +Select queries +-------------- + +Functions can be used in ``select`` queries - here's an example, where we +convert the values to uppercase: + +.. code-block:: python + + >>> from piccolo.query.functions import Upper + + >>> await Band.select( + ... Upper(Band.name, alias="name") + ... ) + + [{"name": "PYTHONISTAS"}] + +Where clauses +------------- + +Functions can also be used in ``where`` clauses. + +.. code-block:: python + + >>> from piccolo.query.functions import Length + + >>> await Band.select( + ... Band.name + ... ).where( + ... Length(Band.name) > 10 + ... ) + + [{"name": "Pythonistas"}] + +Update queries +-------------- + +And even in ``update`` queries: + +.. code-block:: python + + >>> from piccolo.query.functions import Upper + + >>> await Band.update( + ... {Band.name: Upper(Band.name)}, + ... force=True + ... ).returning(Band.name) + + [{"name": "PYTHONISTAS"}, {"name": "RUSTACEANS"}, {"name": "C-SHARPS"}] + +Pretty much everywhere. diff --git a/docs/src/piccolo/functions/custom.rst b/docs/src/piccolo/functions/custom.rst new file mode 100644 index 000000000..c832c257c --- /dev/null +++ b/docs/src/piccolo/functions/custom.rst @@ -0,0 +1,24 @@ +Custom functions +================ + +If there's a database function which Piccolo doesn't provide out of the box, +you can still easily access it by using :class:`QueryString ` +directly. + +QueryString +----------- + +:class:`QueryString ` is the building block of +queries in Piccolo.
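+Each ``{}`` placeholder is filled in safely - plain values are passed to the +database as query parameters rather than being interpolated into the SQL +string, which helps prevent SQL injection. As a rough sketch of the general +pattern (using the example ``Band`` table, and the builtin ``coalesce`` +function): + +.. code-block:: python + + from piccolo.querystring import QueryString + + # Band.popularity and 1000 are bound to the two {} placeholders: + QueryString('coalesce({}, {})', Band.popularity, 1000) +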
+ +If we have a custom function defined in the database called ``slugify``, you +can access it like this: + +.. code-block:: python + + from piccolo.querystring import QueryString + + await Band.select( + Band.name, + QueryString('slugify({})', Band.name, alias='name_slug') + ) diff --git a/docs/src/piccolo/functions/datetime.rst b/docs/src/piccolo/functions/datetime.rst new file mode 100644 index 000000000..ee33e8c4a --- /dev/null +++ b/docs/src/piccolo/functions/datetime.rst @@ -0,0 +1,67 @@ +Datetime functions +================== + +.. currentmodule:: piccolo.query.functions.datetime + +Postgres / Cockroach +-------------------- + +Extract +~~~~~~~ + +.. autoclass:: Extract + + +SQLite +------ + +Strftime +~~~~~~~~ + +.. autoclass:: Strftime + + +Database agnostic +----------------- + +These convenience functions work consistently across database engines. + +They all work very similarly, for example: + +.. code-block:: python + + >>> from piccolo.query.functions import Year + >>> await Concert.select( + ... Year(Concert.starts, alias="start_year") + ... ) + [{"start_year": 2024}] + +Year +~~~~ + +.. autofunction:: Year + +Month +~~~~~ + +.. autofunction:: Month + +Day +~~~ + +.. autofunction:: Day + +Hour +~~~~ + +.. autofunction:: Hour + +Minute +~~~~~~ + +.. autofunction:: Minute + +Second +~~~~~~ + +.. autofunction:: Second diff --git a/docs/src/piccolo/functions/index.rst b/docs/src/piccolo/functions/index.rst new file mode 100644 index 000000000..4f702a1c8 --- /dev/null +++ b/docs/src/piccolo/functions/index.rst @@ -0,0 +1,20 @@ +Functions +========= + +.. hint:: + This is an advanced topic - if you're new to Piccolo you can skip this for + now. + +Functions can be used to modify how queries are run, and what is returned. + +.. toctree:: + :maxdepth: 1 + + ./basic_usage + ./aggregate + ./array + ./datetime + ./math + ./string + ./type_conversion + ./custom diff --git a/docs/src/piccolo/functions/math.rst b/docs/src/piccolo/functions/math.rst new file mode 100644 index 000000000..6b9472764 --- /dev/null +++ b/docs/src/piccolo/functions/math.rst @@ -0,0 +1,28 @@ +Math functions +============== + +.. currentmodule:: piccolo.query.functions.math + +Abs +--- + +.. autoclass:: Abs + :class-doc-from: class + +Ceil +---- + +.. autoclass:: Ceil + :class-doc-from: class + +Floor +----- + +.. autoclass:: Floor + :class-doc-from: class + +Round +----- + +.. autoclass:: Round + :class-doc-from: class diff --git a/docs/src/piccolo/functions/string.rst b/docs/src/piccolo/functions/string.rst new file mode 100644 index 000000000..8d991c956 --- /dev/null +++ b/docs/src/piccolo/functions/string.rst @@ -0,0 +1,45 @@ +String functions +================ + +.. currentmodule:: piccolo.query.functions.string + +Concat +------ + +.. autoclass:: Concat + +Length +------ + +.. autoclass:: Length + :class-doc-from: class + +Lower +----- + +.. autoclass:: Lower + :class-doc-from: class + +Ltrim +----- + +.. autoclass:: Ltrim + :class-doc-from: class + +Reverse +------- + +.. autoclass:: Reverse + :class-doc-from: class + +Rtrim +----- + +.. autoclass:: Rtrim + :class-doc-from: class + +Upper +----- + +.. autoclass:: Upper + :class-doc-from: class diff --git a/docs/src/piccolo/functions/type_conversion.rst b/docs/src/piccolo/functions/type_conversion.rst new file mode 100644 index 000000000..7e2743d2b --- /dev/null +++ b/docs/src/piccolo/functions/type_conversion.rst @@ -0,0 +1,25 @@ +Type conversion functions +========================= + +Cast +---- + +.. 
currentmodule:: piccolo.query.functions.type_conversion + +.. autoclass:: Cast + +Notes on databases +------------------ + +Postgres and CockroachDB have very rich type systems, and you can convert +between most types. SQLite is more limited. + +The following query will work in Postgres / Cockroach, but you might get +unexpected results in SQLite, because it doesn't have a native ``TIME`` column +type: + +.. code-block:: python + + >>> from piccolo.columns import Time + >>> from piccolo.query.functions import Cast + >>> await Concert.select(Cast(Concert.starts, Time())) diff --git a/docs/src/piccolo/getting_started/database_support.rst b/docs/src/piccolo/getting_started/database_support.rst index 0ce384ce0..1106cd350 100644 --- a/docs/src/piccolo/getting_started/database_support.rst +++ b/docs/src/piccolo/getting_started/database_support.rst @@ -4,8 +4,22 @@ Database Support ================ `Postgres `_ is the primary database which Piccolo -was designed for. +was designed for. It's robust, feature-rich, and a great choice for most projects. -Limited `SQLite `_ support is available, -mostly to enable tooling like the :ref:`playground `. Postgres is the only database we -recommend for use in production with Piccolo. +`CockroachDB `_ is also supported. It's designed +to be scalable and fault tolerant, and is mostly compatible with Postgres. +Some minor features may not be supported, but it's fine to use. + +`SQLite `_ support was originally added to +enable tooling like the :ref:`playground `, but over time we've +added more and more support. Many people successfully use SQLite and Piccolo +together in production. The main missing feature is support for +:ref:`automatic database migrations ` due to SQLite's limited +support for ``ALTER TABLE`` ``DDL`` statements. + +What about other databases? +--------------------------- + +Our focus is on providing great support for a limited number of databases +(especially Postgres). However, it's likely that we'll support more databases in +the future. diff --git a/docs/src/piccolo/getting_started/example_schema.rst b/docs/src/piccolo/getting_started/example_schema.rst index d55fa39f3..1a33a418c 100644 --- a/docs/src/piccolo/getting_started/example_schema.rst +++ b/docs/src/piccolo/getting_started/example_schema.rst @@ -3,7 +3,10 @@ Example Schema ============== -This is the schema used by the example queries throughout the docs. +This is the schema used by the example queries throughout the docs, and also +in the :ref:`playground`. + +``Manager`` and ``Band`` are most commonly used: .. code-block:: python @@ -20,4 +23,30 @@ This is the schema used by the example queries throughout the docs. manager = ForeignKey(references=Manager) popularity = Integer() +We sometimes use these other tables in the examples too: + +.. code-block:: python + + class Venue(Table): + name = Varchar() + capacity = Integer() + + + class Concert(Table): + band_1 = ForeignKey(references=Band) + band_2 = ForeignKey(references=Band) + venue = ForeignKey(references=Venue) + starts = Timestamp() + duration = Interval() + + + class Ticket(Table): + concert = ForeignKey(references=Concert) + price = Numeric() + + + class RecordingStudio(Table): + name = Varchar() + facilities = JSONB() + To understand more about defining your own schemas, see :ref:`DefiningSchema`.
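+As a quick sketch of how these tables connect, foreign keys can be traversed +directly within queries - the same ``Band.manager.name`` pattern used +elsewhere in these docs. For example, fetching each concert along with its +venue and bands: + +.. code-block:: python + + >>> await Concert.select( + ... Concert.venue.name, + ... Concert.band_1.name, + ... Concert.band_2.name + ... )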
diff --git a/docs/src/piccolo/getting_started/index.rst b/docs/src/piccolo/getting_started/index.rst index 3ce421721..373cad786 100644 --- a/docs/src/piccolo/getting_started/index.rst +++ b/docs/src/piccolo/getting_started/index.rst @@ -3,11 +3,14 @@ Getting Started .. toctree:: :caption: Contents: + :maxdepth: 1 ./what_is_piccolo ./database_support ./installing_piccolo ./playground ./setup_postgres - ./sync_and_async + ./setup_cockroach + ./setup_sqlite ./example_schema + ./sync_and_async diff --git a/docs/src/piccolo/getting_started/installing_piccolo.rst b/docs/src/piccolo/getting_started/installing_piccolo.rst index 0354a217b..c5cad56d3 100644 --- a/docs/src/piccolo/getting_started/installing_piccolo.rst +++ b/docs/src/piccolo/getting_started/installing_piccolo.rst @@ -6,22 +6,24 @@ Python You need `Python 3.7 `_ or above installed on your system. +------------------------------------------------------------------------------- + Pip --- -Now install piccolo, ideally inside a `virtualenv `_: +Now install Piccolo, ideally inside a `virtualenv `_: -.. code-block:: python +.. code-block:: bash # Optional - creating a virtualenv on Unix: - python3.7 -m venv my_project + python3 -m venv my_project cd my_project source bin/activate # The important bit: pip install piccolo - # Install Piccolo with PostgreSQL driver: + # Install Piccolo with PostgreSQL or CockroachDB driver: pip install 'piccolo[postgres]' # Install Piccolo with SQLite driver: @@ -30,7 +32,15 @@ Now install piccolo, ideally inside a `virtualenv `_ shell. +It creates an :ref:`example schema ` for you, populates it +with data, and launches an `iPython `_ shell. You can follow along with the tutorials without first learning advanced concepts like migrations. @@ -19,7 +20,10 @@ It's a nice place to experiment with querying / inserting / deleting data using Piccolo, no matter how experienced you are. .. warning:: - Each time you launch the playground it flushes out the existing tables and rebuilds them, so don't use it for anything permanent! + Each time you launch the playground it flushes out the existing tables and + rebuilds them, so don't use it for anything permanent! + +------------------------------------------------------------------------------- SQLite ------ @@ -28,11 +32,15 @@ SQLite is used by default, which provides a zero config way of getting started. A ``piccolo.sqlite`` file will get created in the current directory. +------------------------------------------------------------------------------- + Advanced usage --------------- -To see how to use the playground with Postgres, and other advanced usage, see -:ref:`PlaygroundAdvanced`. +To see how to use the playground with Postgres or Cockroach, and other +advanced usage, see :ref:`PlaygroundAdvanced`. + +------------------------------------------------------------------------------- Test queries ------------ @@ -46,10 +54,12 @@ Give these queries a go: ..
code-block:: python - Band.select().run_sync() - Band.objects().run_sync() - Band.select(Band.name).run_sync() - Band.select(Band.name, Band.manager.name).run_sync() + await Band.select() + await Band.objects() + await Band.select(Band.name) + await Band.select(Band.name, Band.manager.name) + +------------------------------------------------------------------------------- Tab completion is your friend ----------------------------- diff --git a/docs/src/piccolo/getting_started/setup_cockroach.rst b/docs/src/piccolo/getting_started/setup_cockroach.rst new file mode 100644 index 000000000..96c5187ec --- /dev/null +++ b/docs/src/piccolo/getting_started/setup_cockroach.rst @@ -0,0 +1,68 @@ +.. _setting_up_cockroach: + +############### +Setup Cockroach +############### + +Installation +************ + +Follow the `instructions for your OS `_. + + +Versions +-------- + +We support the latest stable version. + +.. note:: + Features using ``format()`` will be available in v22.2 or higher, but we recommend using the stable version so you can upgrade automatically when it becomes generally available. + + Cockroach is designed to be a "rolling database": Upgrades are as simple as switching out to the next version of a binary (or changing a number in a ``docker-compose.yml``). This has one caveat: You cannot upgrade an "alpha" release. It is best to stay on the latest stable. + + +------------------------------------------------------------------------------- + +Creating a database +******************* + +cockroach sql +------------- + +CockroachDB comes with its own management tooling. + +.. code-block:: bash + + cd ~/wherever/you/installed/cockroachdb + cockroach sql --insecure + +Enter the following: + +.. code-block:: bash + + create database piccolo; + use piccolo; + +Management GUI +-------------- + +CockroachDB comes with its own web-based management GUI available on localhost: http://127.0.0.1:8080/ + +Beekeeper Studio +---------------- + +If you prefer a GUI, Beekeeper Studio is recommended and has an `installer available `_. + + +------------------------------------------------------------------------------- + + +Column Types +************ + +As of this writing, CockroachDB will always convert ``JSON`` to ``JSONB`` and will always report ``INTEGER`` as ``BIGINT``. + +Piccolo will automatically handle these special cases for you, but we recommend being explicit about this to prevent complications in future versions of Piccolo. + +* Use ``JSONB()`` instead of ``JSON()`` +* Use ``BigInt()`` instead of ``Integer()`` diff --git a/docs/src/piccolo/getting_started/setup_postgres.rst b/docs/src/piccolo/getting_started/setup_postgres.rst index 6ef181334..b95ba521b 100644 --- a/docs/src/piccolo/getting_started/setup_postgres.rst +++ b/docs/src/piccolo/getting_started/setup_postgres.rst @@ -79,14 +79,4 @@ DEB packages are available for `Ubuntu `_. - -------------------------------------------------------------------------------- - -What about other databases? -*************************** - -At the moment the focus is on providing the best Postgres experience possible, -along with some SQLite support. Other databases may be supported in the future. +Piccolo is tested on most major Postgres versions (see the `GitHub Actions file `_). diff --git a/docs/src/piccolo/getting_started/setup_sqlite.rst b/docs/src/piccolo/getting_started/setup_sqlite.rst new file mode 100644 index 000000000..f6f47361d --- /dev/null +++ b/docs/src/piccolo/getting_started/setup_sqlite.rst @@ -0,0 +1,28 @@ +.. 
_set_up_sqlite: + +Setup SQLite +============ + +Installation +------------ + +The good news is SQLite is good to go out of the box with Python. + +Some Piccolo features are only available with newer SQLite versions. + +.. _check_sqlite_version: + +Check version +------------- + +To check which SQLite version you're using, simply open a Python terminal, and +do the following: + +.. code-block:: python + + >>> import sqlite3 + >>> sqlite3.sqlite_version + '3.39.0' + +The easiest way to upgrade your SQLite version is to install the latest version +of Python. diff --git a/docs/src/piccolo/getting_started/sync_and_async.rst b/docs/src/piccolo/getting_started/sync_and_async.rst index b063ed583..12b7e67a7 100644 --- a/docs/src/piccolo/getting_started/sync_and_async.rst +++ b/docs/src/piccolo/getting_started/sync_and_async.rst @@ -3,59 +3,60 @@ Sync and Async ============== -One of the main motivations for making Piccolo was the lack of options for -ORMs which support asyncio. +One of the motivations for making Piccolo was the lack of ORMs and query +builders which support asyncio. -However, you can use Piccolo in synchronous apps as well, whether that be a -WSGI web app, or a data science script. +Piccolo is designed to be async first. However, you can use Piccolo in +synchronous apps as well, whether that be a WSGI web app, or a data science +script. -Sync example ------------- - -.. code-block:: python +------------------------------------------------------------------------------- - from my_schema import Band +Async example +------------- +You can await a query to run it: - def main(): - print(Band.select().run_sync()) +.. code-block:: python + >>> await Band.select(Band.name) + [{'name': 'Pythonistas'}] - if __name__ == '__main__': - main() - -Async example -------------- +Alternatively, you can await a query's ``run`` method: .. code-block:: python - import asyncio - from my_schema import Band + # This makes it extra explicit that a database query is being made: + >>> await Band.select(Band.name).run() + # It also gives you more control over how the query is run. + # For example, if we wanted to bypass the connection pool for some reason: + >>> await Band.select(Band.name).run(in_pool=False) - async def main(): - print(await Band.select().run()) +Using the async version is useful for applications which require high +throughput. Piccolo makes building an ASGI web app really simple - see +:ref:`ASGICommand`. +------------------------------------------------------------------------------- - if __name__ == '__main__': - asyncio.run(main()) +Sync example +------------ -Which to use? -------------- +This lets you execute a query in an application which isn't using asyncio: + +.. code-block:: python -A lot of the time, using the sync version works perfectly fine. Many of the -examples use the sync version. + >>> Band.select(Band.name).run_sync() + [{'name': 'Pythonistas'}] -Using the async version is useful for web applications which require high -throughput, based on `ASGI frameworks `_. -Piccolo makes building an ASGI web app really simple - see :ref:`ASGICommand`. +------------------------------------------------------------------------------- Explicit -------- -By using ``run`` and ``run_sync``, it makes it very explicit when a query is +By using ``await`` and ``run_sync``, it makes it very explicit when a query is actually being executed. 
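+For example, here's a minimal sketch using the example ``Band`` table - +building the query runs no SQL, and the database is only queried when the +query is awaited: + +.. code-block:: python + + # No query is run yet: + query = Band.select(Band.name).where(Band.popularity > 100) + + # The query is only executed here: + bands = await query +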
-Until you execute one of those methods, you can chain as many methods onto your +Until you execute ``await`` or ``run_sync``, you can chain as many methods onto your query as you like, safe in the knowledge that no database queries are being made. diff --git a/docs/src/piccolo/getting_started/what_is_piccolo.rst b/docs/src/piccolo/getting_started/what_is_piccolo.rst index a7ab67cda..d9f68b680 100644 --- a/docs/src/piccolo/getting_started/what_is_piccolo.rst +++ b/docs/src/piccolo/getting_started/what_is_piccolo.rst @@ -5,10 +5,35 @@ Piccolo is a fast, easy to learn ORM and query builder. Some of its standout features are: -* Support for sync and async - see :ref:`SyncAndAsync`. -* A builtin playground, which makes learning a breeze - see :ref:`Playground`. -* Works great with `iPython `_ and - `VSCode `_ - see :ref:`tab_completion`. -* Batteries included - a :ref:`User model and authentication `, :ref:`migrations `, an :ref:`admin `, - and more. +* Support for :ref:`sync and async `. +* A builtin :ref:`playground `, which makes learning a breeze. +* Fully type annotated, with great :ref:`tab completion support ` - it works great with + `iPython `_ and `VSCode `_. +* Batteries included - a :ref:`User model and authentication `, + :ref:`migrations `, an :ref:`admin `, and more. * Templates for creating your own :ref:`ASGI web app `. + +History +------- + +Piccolo was created while its author was working at a design agency, where almost all projects +being undertaken were API driven (often with high traffic), and required +web sockets. The author was naturally interested in the possibilities of :mod:`asyncio`. +Piccolo is built from the ground up with asyncio in mind. Likewise, Piccolo +makes extensive use of :mod:`type annotations `, another innovation in +Python around the time Piccolo was started. + +A really important thing when working at a design agency is having a **great +admin interface**. A huge amount of effort has gone into +`Piccolo Admin `_ +to make something you'd be proud to give to a client. + +A lot of batteries are included because Piccolo is a pragmatic framework +focused on delivering quality, functional apps to customers. This is why we have +templating tools like ``piccolo asgi new`` for getting a web app started +quickly, automatic database migrations for making iteration fast, and lots of +authentication middleware and endpoints for rapidly +`building APIs `_ out of the box. + +Piccolo has been used extensively by the author on professional projects, for +a range of corporate and startup clients. diff --git a/docs/src/piccolo/help/index.rst b/docs/src/piccolo/help/index.rst new file mode 100644 index 000000000..29669a097 --- /dev/null +++ b/docs/src/piccolo/help/index.rst @@ -0,0 +1,5 @@ +Help +==== + +If you have any questions then the best place to ask them is the +`discussions section on our GitHub page `_. diff --git a/docs/src/piccolo/migrations/create.rst b/docs/src/piccolo/migrations/create.rst index 47ae35c36..486c46165 100644 --- a/docs/src/piccolo/migrations/create.rst +++ b/docs/src/piccolo/migrations/create.rst @@ -2,79 +2,238 @@ Creating migrations =================== Migrations are Python files which are used to modify the database schema in a -controlled way. Each migration belongs to a Piccolo app (see :ref:`PiccoloApps`). +controlled way. Each migration belongs to a :ref:`Piccolo app `. You can either manually populate migrations, or allow Piccolo to do it for you -automatically. To create an empty migration: +automatically.
+ +We recommend using :ref:`auto migrations ` where possible, +as it saves you time. + +------------------------------------------------------------------------------- + +Manual migrations +----------------- + +First, let's create an empty migration: .. code-block:: bash piccolo migrations new my_app -This creates a new migration file in the migrations folder of the app. The -migration filename is a timestamp, which also serves as the migration ID. +This creates a new migration file in the migrations folder of the app. By +default, the migration filename is the name of the app, followed by a timestamp, +but you can rename it to anything you want: .. code-block:: bash piccolo_migrations/ - 2021-08-06T16-22-51-415781.py + my_app_2022_12_06T13_58_23_024723.py + +.. note:: + We changed the naming convention for migration files in version ``0.102.0`` + (previously they were like ``2022-12-06T13-58-23-024723.py``). As mentioned, + the name isn't important - change it to anything you want. The new format + was chosen because a Python file should start with a letter by convention. The contents of an empty migration file looks like this: .. code-block:: python - from piccolo.apps.migrations.auto import MigrationManager + from piccolo.apps.migrations.auto.migration_manager import MigrationManager - ID = '2021-08-06T16:22:51:415781' - VERSION = "0.29.0" # The version of Piccolo used to create it + ID = "2022-02-26T17:38:44:758593" + VERSION = "0.69.2" # The version of Piccolo used to create it DESCRIPTION = "Optional description" async def forwards(): - manager = MigrationManager(migration_id=ID, app_name="my_app", description=DESCRIPTION) + manager = MigrationManager( + migration_id=ID, + app_name="my_app", + description=DESCRIPTION + ) def run(): + # Replace this with something useful: print(f"running {ID}") manager.add_raw(run) return manager -Replace the `run` function with whatever you want the migration to do - +The ``ID`` is very important - it uniquely identifies the migration, and +shouldn't be changed. + +Replace the ``run`` function with whatever you want the migration to do - typically running some SQL. It can be a function or a coroutine. -The golden rule ---------------- +Running raw SQL +~~~~~~~~~~~~~~~ -Never import your tables directly into a migration, and run methods on them. +If you want to run raw SQL within your migration, you can do so as follows: -This is a **bad example**: +.. 
code-block:: python + + from piccolo.apps.migrations.auto.migration_manager import MigrationManager + from piccolo.table import Table + + + ID = "2025-07-28T09:51:54:296860" + VERSION = "1.27.1" + DESCRIPTION = "Updating each band's popularity" + + + # This is just a dummy table we use to execute raw SQL with: + class RawTable(Table): + pass + + + async def forwards(): + manager = MigrationManager( + migration_id=ID, + app_name="my_app", + description=DESCRIPTION + ) + + ############################################################# + # This will get run when using `piccolo migrations forwards`: + + async def run(): + await RawTable.raw('UPDATE band SET popularity={}', 1000) + + manager.add_raw(run) + + ############################################################# + # If we want to run some code when reversing the migration, + # using `piccolo migrations backwards`: + + async def run_backwards(): + await RawTable.raw('UPDATE band SET popularity={}', 0) + + manager.add_raw_backwards(run_backwards) + + ############################################################# + # We must always return the MigrationManager: + + return manager + +.. hint:: You can learn more about :ref:`raw queries here `. + +Using your ``Table`` classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the above example, we executed raw SQL, but what if we wanted to use the +``Table`` classes from our project instead? + +We have to be quite careful with this. Here's an example: .. code-block:: python - from ..tables import Band + from piccolo.apps.migrations.auto.migration_manager import MigrationManager + + # We're importing a table from our project: + from music.tables import Band - ID = '2021-08-06T16:22:51:415781' - VERSION = "0.29.0" # The version of Piccolo used to create it - DESCRIPTION = "Optional description" + + ID = "2025-07-28T09:51:54:296860" + VERSION = "1.27.1" + DESCRIPTION = "Updating each band's popularity" async def forwards(): - manager = MigrationManager(migration_id=ID) + manager = MigrationManager( + migration_id=ID, + app_name="my_app", + description=DESCRIPTION + ) async def run(): - await Band.create_table().run() + await Band.update({Band.popularity: 1000}, force=True) manager.add_raw(run) return manager -The reason you don't want to do this, is your tables will change over time. If -someone runs your migrations in the future, they will get different results. -Make your migrations completely independent of other code, so they're -self contained and repeatable. +We want our migrations to be repeatable - so if someone runs them a year from +now, they will get the same results. + +By directly importing our tables, we have the following risks: + +* If the ``Band`` class is deleted from the codebase, it could break old + migrations. +* If we modify the ``Band`` class, perhaps by removing columns, this could also + break old migrations. + +Try and make your migration files independent of other application code, so +they're self contained and repeatable. Even though it goes against `DRY `_, +it's better to copy the relevant tables into your migration file: + +.. code-block:: python + + from piccolo.apps.migrations.auto.migration_manager import MigrationManager + from piccolo.columns.column_types import Integer + from piccolo.table import Table + + + ID = "2025-07-28T09:51:54:296860" + VERSION = "1.27.1" + DESCRIPTION = "Updating each band's popularity" + + + # We defined the table within the file, rather than importing it. 
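+ # Note: only the columns which this migration actually uses need to be declared here.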
+    class Band(Table):
+        popularity = Integer()
+
+
+    async def forwards():
+        manager = MigrationManager(
+            migration_id=ID,
+            app_name="my_app",
+            description=DESCRIPTION
+        )
+
+        async def run():
+            await Band.update({Band.popularity: 1000}, force=True)
+
+        manager.add_raw(run)
+        return manager
+
+Another alternative is to use the ``MigrationManager.get_table_from_snapshot``
+method to get a table from the migration history. This is very convenient,
+especially if the table is large, with many foreign keys.
+
+.. code-block:: python
+
+    from piccolo.apps.migrations.auto.migration_manager import MigrationManager
+
+
+    ID = "2025-07-28T09:51:54:296860"
+    VERSION = "1.27.1"
+    DESCRIPTION = "Updating each band's popularity"
+
+
+    async def forwards():
+        manager = MigrationManager(
+            migration_id=ID,
+            app_name="",
+            description=DESCRIPTION
+        )
+
+        async def run():
+            # We get a table from the migration history.
+            Band = await manager.get_table_from_snapshot(
+                app_name="music", table_class_name="Band"
+            )
+            await Band.update({"popularity": 1000}, force=True)
+
+        manager.add_raw(run)
+
+        return manager
 
 -------------------------------------------------------------------------------
 
+.. _AutoMigrations:
+
 Auto migrations
 ---------------
@@ -83,8 +242,8 @@ supports `auto migrations` which can save a great deal of time.
 
 Piccolo will work out which tables to add by comparing previous auto
 migrations, and your current tables. In order for this to work, you have to register
-your app's tables with the `AppConfig` in the piccolo_app.py file at the root
-of your app (see :ref:`PiccoloApps`).
+your app's tables with the ``AppConfig`` in the ``piccolo_app.py`` file at the
+root of your app (see :ref:`PiccoloApps`).
 
 Creating an auto migration:
 
@@ -122,5 +281,5 @@ can specify it when creating the migration:
 
     piccolo migrations new my_app --auto --desc="Adding name column"
 
-The Piccolo CLI will then use this description where appropriate when dealing
-with migrations.
+The Piccolo CLI will then use this description when listing migrations, to make
+them easier to identify.
diff --git a/docs/src/piccolo/migrations/running.rst b/docs/src/piccolo/migrations/running.rst
index 29069c8b7..5b00db3f6 100644
--- a/docs/src/piccolo/migrations/running.rst
+++ b/docs/src/piccolo/migrations/running.rst
@@ -13,22 +13,96 @@ When the migration is run, the forwards function is executed. To do this:
 
     piccolo migrations forwards my_app
 
+Multiple apps
+~~~~~~~~~~~~~
+
+If you have multiple apps you can run them all using:
+
+.. code-block:: bash
+
+    piccolo migrations forwards all
+
+Migration table
+~~~~~~~~~~~~~~~
+
+When running the migrations, Piccolo will automatically create a database table
+called ``migration`` if it doesn't already exist. Each time a migration is
+successfully run, a new row is added to this table.
+
+.. _FakeMigration:
+
+Fake
+~~~~
+
+We can 'fake' running a migration - we record that it ran in the database
+without actually running it.
+
+There are two ways to do this - by passing in the ``--fake`` flag on the
+command line:
+
+.. code-block:: bash
+
+    piccolo migrations forwards my_app 2022-09-04T19:44:09 --fake
+
+Or by setting ``fake=True`` on the ``MigrationManager`` within the migration
+file.
+
+.. code-block:: python
+
+    async def forwards():
+        manager = MigrationManager(
+            migration_id=ID,
+            app_name="app",
+            description=DESCRIPTION,
+            fake=True
+        )
+        ...
+
+This is useful if we started from an existing database using
+``piccolo schema generate``, and the initial migration we generated is for
+tables which already exist, so we fake run it.
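+
+For example, assuming a hypothetical app called ``my_app`` whose initial auto
+migration matches tables which already exist in the database, the workflow
+might look like this (the migration ID shown is illustrative):
+
+.. code-block:: bash
+
+    # Generate Table classes from the existing database:
+    piccolo schema generate > tables.py
+
+    # Create an initial auto migration for the app:
+    piccolo migrations new my_app --auto
+
+    # Record the migration as having run, without executing it:
+    piccolo migrations forwards my_app 2022-09-04T19:44:09 --fake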
 
 -------------------------------------------------------------------------------
 
 Reversing migrations
 --------------------
 
-To reverse the migration, run this:
+To reverse the migration, run the following command, specifying the ID of a
+migration:
 
 .. code-block:: bash
 
-    piccolo migrations backwards 2018-09-04T19:44:09
+    piccolo migrations backwards my_app 2022-09-04T19:44:09
+
+Piccolo will then reverse the migrations for the given app, starting with the
+most recent migration, up to and including the migration with the specified ID.
 
 You can try going forwards and backwards a few times to make sure it works as
 expected.
 
 -------------------------------------------------------------------------------
 
+Preview
+-------
+
+To see the SQL queries of a migration without actually running them, use the
+``--preview`` flag.
+
+This works when running migrations forwards:
+
+.. code-block:: bash
+
+    piccolo migrations forwards my_app --preview
+
+Or backwards:
+
+.. code-block:: bash
+
+    piccolo migrations backwards my_app 2022-09-04T19:44:09 --preview
+
+-------------------------------------------------------------------------------
+
 Checking migrations
 -------------------
 
@@ -37,3 +111,24 @@ You can easily check which migrations have and haven't ran using the following:
 
 .. code-block:: bash
 
     piccolo migrations check
+
+-------------------------------------------------------------------------------
+
+Source
+------
+
+These are the underlying Python functions which are called, so you can see
+all available options. These functions are converted into a CLI using
+`targ `_.
+
+.. currentmodule:: piccolo.apps.migrations.commands.forwards
+
+.. autofunction:: forwards
+
+.. currentmodule:: piccolo.apps.migrations.commands.backwards
+
+.. autofunction:: backwards
+
+.. currentmodule:: piccolo.apps.migrations.commands.check
+
+.. autofunction:: check
diff --git a/docs/src/piccolo/playground/advanced.rst b/docs/src/piccolo/playground/advanced.rst
index 8c9320c4e..e8f459b9f 100644
--- a/docs/src/piccolo/playground/advanced.rst
+++ b/docs/src/piccolo/playground/advanced.rst
@@ -13,7 +13,7 @@ first.
 
 Install Postgres
 ~~~~~~~~~~~~~~~~
 
-See :ref:`setting_up_postgres`.
+See :ref:`the docs on setting up Postgres `.
 
 Create database
 ~~~~~~~~~~~~~~~
 
@@ -43,3 +43,67 @@ When you have the database setup, you can connect to it as follows:
 
 .. code-block:: bash
 
     piccolo playground run --engine=postgres
+
+CockroachDB
+-----------
+
+If you want to use CockroachDB instead of SQLite, you need to create a database
+first.
+
+
+Install CockroachDB
+~~~~~~~~~~~~~~~~~~~
+
+See the `installation guide for your OS `_.
+
+Create database
+~~~~~~~~~~~~~~~
+The playground is for testing and learning purposes only, so you can start a CockroachDB
+`single node with the insecure flag `_
+(for non-production testing only) like this:
+
+.. code-block:: bash
+
+    cockroach start-single-node --insecure
+
+After that, in a new terminal window, you can create a database like this:
+
+.. code-block:: bash
+
+    cockroach sql --insecure --execute="DROP DATABASE IF EXISTS piccolo_playground CASCADE;CREATE DATABASE piccolo_playground;"
+
+By default the playground expects a local database to exist with the following
+credentials:
+
+
+.. code-block:: bash
+
+    user: "root"
+    password: ""
+    host: "localhost"  # or 127.0.0.1
+    database: "piccolo_playground"
+    port: 26257
+
+
+Connecting
+~~~~~~~~~~
+
+When you have the database setup, you can connect to it as follows:
+
+.. code-block:: bash
+
+    piccolo playground run --engine=cockroach
+
+iPython
+-------
+
+The playground is built on top of iPython. We provide sensible defaults out of
+the box for syntax highlighting etc. However, to use your own custom iPython
+profile (located in ``~/.ipython``), do the following:
+
+.. code-block:: bash
+
+    piccolo playground run --ipython_profile
+
+See the `iPython docs `_
+for more information.
diff --git a/docs/src/piccolo/projects_and_apps/images/schema_graph_output.png b/docs/src/piccolo/projects_and_apps/images/schema_graph_output.png
new file mode 100644
index 000000000..6bd8deb49
Binary files /dev/null and b/docs/src/piccolo/projects_and_apps/images/schema_graph_output.png differ
diff --git a/docs/src/piccolo/projects_and_apps/included_apps.rst b/docs/src/piccolo/projects_and_apps/included_apps.rst
index b7d279333..5966c92af 100644
--- a/docs/src/piccolo/projects_and_apps/included_apps.rst
+++ b/docs/src/piccolo/projects_and_apps/included_apps.rst
@@ -4,10 +4,17 @@ Included Apps
 Just as you can modularise your own code using :ref:`apps`, Piccolo itself
 ships with several builtin apps, which provide a lot of its functionality.
 
+-------------------------------------------------------------------------------
+
 Auto includes
 -------------
 
-The following are registered with your :ref:`AppRegistry` automatically:
+The following are registered with your :ref:`AppRegistry` automatically.
+
+.. hint:: To find out more about each of these commands you can use the
+   ``--help`` flag on the command line. For example ``piccolo app new --help``.
+
+-------------------------------------------------------------------------------
 
 app
 ~~~
 
@@ -18,6 +25,8 @@ Lets you create new Piccolo apps. See :ref:`PiccoloApps`.
 
     piccolo app new
 
+-------------------------------------------------------------------------------
+
 asgi
 ~~~~
 
@@ -27,6 +36,94 @@ Lets you scaffold an ASGI web app. See :ref:`ASGICommand`.
 
     piccolo asgi new
 
+-------------------------------------------------------------------------------
+
+.. _Fixtures:
+
+fixtures
+~~~~~~~~
+
+Fixtures are used when you want to seed your database with essential data (for
+example, country names).
+
+Once you have created a fixture, it can be used by your colleagues when setting
+up an application on their local machines, or when deploying to a new
+environment.
+
+Databases such as Postgres have built-in ways of dumping and restoring data
+(via ``pg_dump`` and ``pg_restore``). Some reasons to use the fixtures app
+instead:
+
+* When you want the data to be loadable in a range of database types and
+  versions.
+* Fixtures are stored in JSON, which is a bit friendlier for source control.
+
+dump
+^^^^
+
+To dump the data into a new fixture file:
+
+.. code-block:: bash
+
+    piccolo fixtures dump > fixtures.json
+
+By default, the fixture contains data from all apps and tables. You can specify
+a subset of apps and tables instead, for example:
+
+.. code-block:: bash
+
+    piccolo fixtures dump --apps=blog --tables=Post > fixtures.json
+
+Or for multiple apps / tables:
+
+.. code-block:: bash
+
+    piccolo fixtures dump --apps=blog,shop --tables=Post,Product > fixtures.json
+
+
+load
+^^^^
+
+To load the fixture:
+
+.. code-block:: bash
+
+    piccolo fixtures load fixtures.json
+
+If you load the fixture again, you will get primary key errors because the rows
+already exist in the database. But what if we need to run it again, because we
+had a typo in our fixture, or were missing some data? We can upsert the data
+using ``--on_conflict``.
+
+There are two options:
+
+1. ``DO NOTHING`` - if any of the rows already exist in the database, just
+   leave them as they are, and don't raise an exception.
+2. ``DO UPDATE`` - if any of the rows already exist in the database, override
+   them with the latest data in the fixture file.
+
+.. code-block:: bash
+
+    # DO NOTHING
+    piccolo fixtures load fixtures.json --on_conflict='DO NOTHING'
+
+    # DO UPDATE
+    piccolo fixtures load fixtures.json --on_conflict='DO UPDATE'
+
+And finally, if you're loading a really large fixture, you can specify the
+``chunk_size``. By default, Piccolo inserts up to 1,000 rows at a time, as
+the database adapter will complain if a single insert query is too large. So
+if your fixture contains 10,000 rows, this will mean 10 insert queries.
+
+You can tune this number higher or lower if you want (lower if the
+table has a lot of columns, or higher if the table has few columns).
+
+.. code-block:: bash
+
+    piccolo fixtures load fixtures.json --chunk_size=500
+
+-------------------------------------------------------------------------------
+
 meta
 ~~~~
 
@@ -36,11 +133,15 @@ Tells you which version of Piccolo is installed.
 
     piccolo meta version
 
+-------------------------------------------------------------------------------
+
 migrations
 ~~~~~~~~~~
 
 Lets you create and run migrations. See :ref:`Migrations`.
 
+-------------------------------------------------------------------------------
+
 playground
 ~~~~~~~~~~
 
@@ -51,6 +152,8 @@ Lets you learn the Piccolo query syntax, using an example schema. See
 
     piccolo playground run
 
+-------------------------------------------------------------------------------
+
 project
 ~~~~~~~
 
@@ -60,6 +163,60 @@ Lets you create a new ``piccolo_conf.py`` file. See :ref:`PiccoloProjects`.
 
     piccolo project new
 
+.. _SchemaApp:
+
+-------------------------------------------------------------------------------
+
+schema
+~~~~~~
+
+generate
+^^^^^^^^
+
+Lets you auto generate Piccolo ``Table`` classes from an existing database.
+Make sure the credentials in ``piccolo_conf.py`` are for the database you're
+interested in, then run the following:
+
+.. code-block:: bash
+
+    piccolo schema generate > tables.py
+
+.. warning:: This feature is still a work in progress. However, even in its
+   current form it will save you a lot of time. Be sure to check the
+   generated code to make sure it's correct.
+
+graph
+^^^^^
+
+A basic schema visualisation tool. It prints out the contents of a GraphViz dot
+file representing your schema.
+
+.. code-block:: bash
+
+    piccolo schema graph
+
+You can pipe the output to your clipboard (``piccolo schema graph | pbcopy``
+on a Mac), then paste it into a `website like this `_
+to turn it into an image file.
+
+Or if you have `Graphviz `_ installed on your
+machine, you can do this to create an image file:
+
+.. code-block:: bash
+
+    piccolo schema graph | dot -Tpdf -o graph.pdf
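+
+Graphviz supports other output formats too; for example, to create a PNG
+image instead:
+
+.. code-block:: bash
+
+    piccolo schema graph | dot -Tpng -o graph.png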
+
+Here's an example of a generated image:
+
+.. image:: ./images/schema_graph_output.png
+   :target: /_images/schema_graph_output.png
+
+.. note::
+
+   There is a `video tutorial on YouTube `__.
+
+-------------------------------------------------------------------------------
+
 shell
 ~~~~~
@@ -70,6 +227,12 @@ Launches an iPython shell, and automatically imports all of your registered
 
     piccolo shell run
 
+.. note::
+
+   There is a `video tutorial on YouTube `__.
+
+-------------------------------------------------------------------------------
+
 sql_shell
 ~~~~~~~~~
 
@@ -84,6 +247,61 @@ need to run raw SQL queries on your database.
 
 For it to work, the underlying command needs to be on the path (i.e. ``psql``
 or ``sqlite3`` depending on which you're using).
 
+.. note::
+
+   There is a `video tutorial on YouTube `__.
+
+-------------------------------------------------------------------------------
+
+.. _TesterApp:
+
+tester
+~~~~~~
+
+Launches `pytest `_, which runs your unit test suite. The
+advantage of using this rather than running ``pytest`` directly, is the
+``PICCOLO_CONF`` environment variable will automatically be set before the
+testing starts, and will be restored to its initial value once the tests
+finish.
+
+.. code-block:: bash
+
+    piccolo tester run
+
+Setting the :ref:`PICCOLO_CONF` environment variable means your
+code will use the database engine specified in that file for the duration of
+the testing.
+
+By default ``piccolo tester run`` sets ``PICCOLO_CONF`` to
+``'piccolo_conf_test'``, meaning that a file called ``piccolo_conf_test.py``
+will be imported.
+
+Within the ``piccolo_conf_test.py`` file, override the database settings, so it
+uses a test database:
+
+.. code-block:: python
+
+    from piccolo_conf import *
+
+    DB = PostgresEngine(
+        config={
+            "database": "my_app_test"
+        }
+    )
+
+
+If you prefer, you can set a custom ``PICCOLO_CONF`` value:
+
+.. code-block:: bash
+
+    piccolo tester run --piccolo_conf=my_custom_piccolo_conf
+
+You can also pass arguments to pytest:
+
+.. code-block:: bash
+
+    piccolo tester run --pytest_args="-s foo"
+
 -------------------------------------------------------------------------------
 
 Optional includes
diff --git a/docs/src/piccolo/projects_and_apps/index.rst b/docs/src/piccolo/projects_and_apps/index.rst
index 36f46f25b..c77a77a35 100644
--- a/docs/src/piccolo/projects_and_apps/index.rst
+++ b/docs/src/piccolo/projects_and_apps/index.rst
@@ -12,3 +12,7 @@ application.
    ./piccolo_projects
    ./piccolo_apps
    ./included_apps
+
+.. note::
+
+   There is a `video tutorial on YouTube `_.
diff --git a/docs/src/piccolo/projects_and_apps/piccolo_apps.rst b/docs/src/piccolo/projects_and_apps/piccolo_apps.rst
index b83cc1df5..36a3b6330 100644
--- a/docs/src/piccolo/projects_and_apps/piccolo_apps.rst
+++ b/docs/src/piccolo/projects_and_apps/piccolo_apps.rst
@@ -1,8 +1,7 @@
 .. _PiccoloApps:
 
-############
 Piccolo Apps
-############
+============
 
 By leveraging Piccolo apps you can:
 
@@ -12,18 +11,17 @@ By leveraging Piccolo apps you can:
 
 -------------------------------------------------------------------------------
 
-***************
 Creating an app
-***************
+---------------
 
 Run the following command within your project:
 
 .. code-block:: bash
 
-    piccolo app new my_app
+    piccolo app new my_app --register
 
-Where `my_app` is your new app's name. This will create a folder like this:
+Where ``my_app`` is your new app's name. This will create a folder like this:
 
 .. code-block:: bash
 
@@ -36,7 +34,8 @@ Where ``my_app`` is your new app's name. This will create a folder like this:
 
 It's important to register your new app with the ``APP_REGISTRY`` in
-`piccolo_conf.py`.
+``piccolo_conf.py``.
+If you used the ``--register`` flag, then this is done
+automatically. Otherwise, add it manually:
 
 .. code-block:: python
 
@@ -44,16 +43,27 @@ It's important to register your new app with the ``APP_REGISTRY`` in
 
     APP_REGISTRY = AppRegistry(apps=['my_app.piccolo_app'])
 
-Anytime you invoke the `piccolo` command, you will now be able to perform
+Anytime you invoke the ``piccolo`` command, you will now be able to perform
 operations on your app, such as :ref:`Migrations`.
 
+root
+~~~~
+
+By default the app is created in the current directory. If you want the app to
+be in a sub folder, you can use the ``--root`` option:
+
+.. code-block:: bash
+
+    piccolo app new my_app --register --root=./apps
+
+The app will then be created in the ``apps`` folder.
+
 -------------------------------------------------------------------------------
 
-*********
 AppConfig
-*********
+---------
 
-Inside your app's `piccolo_app.py` file is an ``AppConfig`` instance. This is
+Inside your app's ``piccolo_app.py`` file is an ``AppConfig`` instance. This is
 how you customise your app's settings.
 
 .. code-block:: python
 
@@ -75,78 +85,50 @@ how you customise your app's settings.
 
     APP_CONFIG = AppConfig(
         app_name='blog',
-        migrations_folder_path=os.path.join(CURRENT_DIRECTORY, 'piccolo_migrations'),
+        migrations_folder_path=os.path.join(
+            CURRENT_DIRECTORY,
+            'piccolo_migrations'
+        ),
         table_classes=[Author, Post, Category, CategoryToPost],
         migration_dependencies=[],
         commands=[]
     )
 
 app_name
-========
+~~~~~~~~
 
-This is used to identify your app, when using the `piccolo` CLI, for example:
+This is used to identify your app, when using the ``piccolo`` CLI, for example:
 
 .. code-block:: bash
 
     piccolo migrations forwards blog
 
 migrations_folder_path
-======================
+~~~~~~~~~~~~~~~~~~~~~~
 
 Specifies where your app's migrations are stored. By default, a folder called
-`piccolo_migrations` is used.
+``piccolo_migrations`` is used.
 
 table_classes
-=============
+~~~~~~~~~~~~~
 
 Use this to register your app's ``Table`` subclasses. This is important for
-auto migrations (see :ref:`Migrations`).
-
-You can register them manually, see the example above, or can use
-``table_finder``.
-
-.. _TableFinder:
-
-table_finder
-------------
-
-Instead of manually registering ``Table`` subclasses, you can use
-``table_finder`` to automatically import any ``Table`` subclasses from a given
-list of modules.
-
-.. code-block:: python
-
-    from piccolo.conf.apps import table_finder
-
-    APP_CONFIG = AppConfig(
-        app_name='blog',
-        migrations_folder_path=os.path.join(CURRENT_DIRECTORY, 'piccolo_migrations'),
-        table_classes=table_finder(modules=['blog.tables']),
-        migration_dependencies=[],
-        commands=[]
-    )
-
-The module path should be from the root of the project (the same directory as
-your ``piccolo_conf.py`` file, rather than a relative path).
-
-You can filter the ``Table`` subclasses returned using tags (see :ref:`TableTags`).
-
-.. currentmodule:: piccolo.conf.apps
-
-.. autofunction:: table_finder
+:ref:`auto migrations `.
+You can register them manually (see the example above), or can use
+:ref:`table_finder `.
 
 migration_dependencies
-======================
+~~~~~~~~~~~~~~~~~~~~~~
 
 Used to specify other Piccolo apps whose migrations need to be run before the
 current app's migrations.
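+
+For example, here's a sketch assuming a hypothetical ``user_profiles`` app
+whose tables ours reference, and that dependencies are given as the import
+path of that app's ``piccolo_app`` module:
+
+.. code-block:: python
+
+    APP_CONFIG = AppConfig(
+        app_name='blog',
+        migrations_folder_path=os.path.join(
+            CURRENT_DIRECTORY,
+            'piccolo_migrations'
+        ),
+        table_classes=[Author, Post, Category, CategoryToPost],
+        # Run the migrations for this app before our own:
+        migration_dependencies=['user_profiles.piccolo_app'],
+        commands=[]
+    )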
 
 commands
-========
+~~~~~~~~
 
 You can register functions and coroutines, which are automatically added to
-the `piccolo` CLI.
+the ``piccolo`` CLI.
 
 The `targ `_ library is used under the hood.
 It makes it really easy to write command-line tools - just use type annotations
 and docstrings. Here's an example:
 
@@ -162,7 +144,7 @@ and docstrings. Here's an example:
 
         The person to greet.
         """
-        print(name)
+        print("hello,", name)
 
 We then register it with the ``AppConfig``.
 
@@ -180,9 +162,17 @@ And from the command line:
 
 .. code-block:: bash
 
     >>> piccolo my_app say_hello bob
-    bob
+    hello, bob
+
+If the code contains an error, add the ``--trace`` flag to the command line to
+see more details in the output.
 
-By convention, store the command definitions in a `commands` folder in your
+.. code-block:: bash
+
+    >>> piccolo my_app say_hello bob --trace
+
+
+By convention, store the command definitions in a ``commands`` folder in your
 app.
 
 .. code-block:: bash
 
@@ -199,9 +189,47 @@ for inspiration.
 
 -------------------------------------------------------------------------------
 
-************
+.. _TableFinder:
+
+table_finder
+------------
+
+Instead of manually registering ``Table`` subclasses, you can use
+``table_finder`` to automatically import any ``Table`` subclasses from a given
+list of modules.
+
+.. code-block:: python
+
+    from piccolo.conf.apps import table_finder
+
+    APP_CONFIG = AppConfig(
+        app_name='blog',
+        migrations_folder_path=os.path.join(
+            CURRENT_DIRECTORY,
+            'piccolo_migrations'
+        ),
+        table_classes=table_finder(modules=['blog.tables']),
+        migration_dependencies=[],
+        commands=[]
+    )
+
+The ``modules`` list can contain absolute paths (e.g. ``'blog.tables'``) or
+relative paths (e.g. ``'.tables'``). If relative paths are used, then the
+``package`` argument must be passed in (``'blog'`` in this case).
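+
+For example, a sketch of the call above rewritten with a relative path:
+
+.. code-block:: python
+
+    from piccolo.conf.apps import table_finder
+
+    # Equivalent to modules=['blog.tables']:
+    tables = table_finder(modules=['.tables'], package='blog')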
+
+You can filter the ``Table`` subclasses returned using :ref:`tags `.
+
+Source
+~~~~~~
+
+.. currentmodule:: piccolo.conf.apps
+
+.. autofunction:: table_finder
+
+-------------------------------------------------------------------------------
+
 Sharing Apps
-************
+------------
 
 By breaking up your project into apps, the project becomes more maintainable.
 You can also share these apps between projects, and they can even be installed
diff --git a/docs/src/piccolo/projects_and_apps/piccolo_projects.rst b/docs/src/piccolo/projects_and_apps/piccolo_projects.rst
index 8ec13d0c2..f85591ede 100644
--- a/docs/src/piccolo/projects_and_apps/piccolo_projects.rst
+++ b/docs/src/piccolo/projects_and_apps/piccolo_projects.rst
@@ -7,10 +7,12 @@ A Piccolo project is a collection of apps.
 
 -------------------------------------------------------------------------------
 
+.. _PiccoloConf:
+
 piccolo_conf.py
 ---------------
 
-A project requires a `piccolo_conf.py` file. To create this file, use the following command:
+A project requires a ``piccolo_conf.py`` file. To create this, use the following command:
 
 .. code-block:: bash
 
@@ -18,8 +20,39 @@ A project requires a ``piccolo_conf.py`` file. To create this, use the following
 
 The file serves two important purposes:
 
-  * Contains your database settings
-  * Is used for registering :ref:`PiccoloApps`.
+* Contains your database settings.
+* Is used for registering :ref:`PiccoloApps`.
+
+Location
+~~~~~~~~
+
+By convention, the ``piccolo_conf.py`` file should be at the root of your project:
+
+.. code-block::
+
+    my_project/
+        piccolo_conf.py
+        my_app/
+            piccolo_app.py
+
+This means that when you use the ``piccolo`` CLI from the ``my_project``
+folder it can import ``piccolo_conf.py``.
+
+If you prefer to keep ``piccolo_conf.py`` in a different location, or to give
+it a different name, you can do so using the ``PICCOLO_CONF`` environment
+variable (see :ref:`PICCOLO_CONF`). For example:
+
+.. code-block::
+
+    my_project/
+        conf/
+            piccolo_conf_local.py
+        my_app/
+            piccolo_app.py
+
+.. code-block:: bash
+
+    export PICCOLO_CONF=conf.piccolo_conf_local
 
 -------------------------------------------------------------------------------
 
@@ -56,8 +89,7 @@ Here's an example:
 
 DB
 --
 
-The DB setting is an ``Engine`` instance. To learn more Engines, see
-:ref:`Engines`.
+The DB setting is an ``Engine`` instance (see the :ref:`Engine docs `).
 
 -------------------------------------------------------------------------------
diff --git a/docs/src/piccolo/query_clauses/as_of.rst b/docs/src/piccolo/query_clauses/as_of.rst
new file mode 100644
index 000000000..97527033b
--- /dev/null
+++ b/docs/src/piccolo/query_clauses/as_of.rst
@@ -0,0 +1,25 @@
+.. _as_of:
+
+as_of
+=====
+
+.. note:: Cockroach only.
+
+You can use the ``as_of`` clause with the following queries:
+
+* :ref:`Select`
+* :ref:`Objects`
+
+To retrieve historical data from 5 minutes ago:
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.name == 'Pythonistas'
+    ).as_of('-5min')
+
+This generates an ``AS OF SYSTEM TIME`` clause. See `documentation `_.
+
+This clause accepts a wide variety of time and interval `string formats `_.
+
+This is very useful for performance, as it will reduce transaction contention
+across a cluster.
diff --git a/docs/src/piccolo/query_clauses/batch.rst b/docs/src/piccolo/query_clauses/batch.rst
index 2fc80a016..0447ebca8 100644
--- a/docs/src/piccolo/query_clauses/batch.rst
+++ b/docs/src/piccolo/query_clauses/batch.rst
@@ -6,8 +6,12 @@ batch
 You can use ``batch`` clauses with the following queries:
 
 * :ref:`Objects`
+* :ref:`Raw`
 * :ref:`Select`
 
+Example
+-------
+
 By default, a query will return as many rows as you ask it for. The
 problem is when you have a table containing millions of rows - you might
 not want to load them all into memory at once. To get around this, you can batch the
 responses.
 
@@ -20,9 +24,24 @@
 
         async for _batch in batch:
             print(_batch)
 
-.. note:: ``batch`` is one of the few query clauses which doesn't require
-   .run() to be used after it in order to execute. ``batch`` effectively
-   replaces ``run``.
+Node
+----
+
+If you're using ``extra_nodes`` with :class:`PostgresEngine `,
+you can specify which node to query:
+
+.. code-block:: python
+
+    # Returns 100 rows at a time from read_replica_db
+    async with await Manager.select().batch(
+        batch_size=100,
+        node="read_replica_db",
+    ) as batch:
+        async for _batch in batch:
+            print(_batch)
+
+Synchronous version
+-------------------
 
 There's currently no synchronous version. However, it's easy enough to
 achieve:
diff --git a/docs/src/piccolo/query_clauses/callback.rst b/docs/src/piccolo/query_clauses/callback.rst
new file mode 100644
index 000000000..da2510311
--- /dev/null
+++ b/docs/src/piccolo/query_clauses/callback.rst
@@ -0,0 +1,68 @@
+.. _callback:
+
+callback
+========
+
+You can use ``callback`` clauses with the following queries:
+
+* :ref:`Select`
+* :ref:`Objects`
+
+Callbacks are used to run arbitrary code after a query completes.
+
+Callback handlers
+-----------------
+
+A callback handler is a function or coroutine that takes query results as
+its only parameter.
+
+For example, you can automatically print the result of a select query using
+``print`` as a callback handler:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name).callback(print)
+    [{'name': 'Pythonistas'}]
+
+Likewise for an objects query:
+
+.. code-block:: python
+
+    >>> await Band.objects().callback(print)
+    []
+
+Transforming results
+--------------------
+
+Callback handlers are able to modify the results of a query by returning a
+value. Note that in the previous examples, the queries returned ``None`` since
+``print`` itself returns ``None``.
+
+To modify query results with a custom callback handler:
+
+.. code-block:: python
+
+    >>> def uppercase_name(band):
+    ...     return band.name.upper()
+
+    >>> await Band.objects().first().callback(uppercase_name)
+    'PYTHONISTAS'
+
+Multiple callbacks
+------------------
+
+You can add as many callbacks to a query as you like. This can be done in two
+ways.
+
+Passing a list of callbacks:
+
+.. code-block:: python
+
+    Band.select(Band.name).callback([handler_a, handler_b])
+
+Chaining ``callback`` clauses:
+
+.. code-block:: python
+
+    Band.select(Band.name).callback(handler_a).callback(handler_b)
+
diff --git a/docs/src/piccolo/query_clauses/distinct.rst b/docs/src/piccolo/query_clauses/distinct.rst
index 9fda873f5..b04705d54 100644
--- a/docs/src/piccolo/query_clauses/distinct.rst
+++ b/docs/src/piccolo/query_clauses/distinct.rst
@@ -9,7 +9,67 @@ You can use ``distinct`` clauses with the following queries:
 
 .. code-block:: python
 
-    >>> Band.select(Band.name).distinct().run_sync()
+    >>> await Band.select(Band.name).distinct()
     [{'name': 'Pythonistas'}]
 
 This is equivalent to ``SELECT DISTINCT name FROM band`` in SQL.
+
+on
+--
+
+Using the ``on`` parameter we can create ``DISTINCT ON`` queries.
+
+.. note:: Postgres and CockroachDB only. For more info, see the `Postgres docs `_.
+
+If we have the following table:
+
+.. code-block:: python
+
+    class Album(Table):
+        band = Varchar()
+        title = Varchar()
+        release_date = Date()
+
+With this data in the database:
+
+.. csv-table:: Albums
+   :file: ./distinct/albums.csv
+   :header-rows: 1
+
+To get the latest album for each band, we can do so with a query like this:
+
+.. code-block:: python
+
+    >>> await Album.select().distinct(
+    ...     on=[Album.band]
+    ... ).order_by(
+    ...     Album.band
+    ... ).order_by(
+    ...     Album.release_date,
+    ...     ascending=False
+    ... )
+
+    [
+        {
+            'id': 2,
+            'band': 'Pythonistas',
+            'title': 'Py album 2022',
+            'release_date': '2022-12-01'
+        },
+        {
+            'id': 4,
+            'band': 'Rustaceans',
+            'title': 'Rusty album 2022',
+            'release_date': '2022-12-01'
+        },
+    ]
+
+The first column specified in ``on`` must match the first column specified in
+``order_by``, otherwise a :class:`DistinctOnError ` will be raised.
+
+Source
+~~~~~~
+
+.. currentmodule:: piccolo.query.mixins
+
+.. autoclass:: DistinctOnError
diff --git a/docs/src/piccolo/query_clauses/distinct/albums.csv b/docs/src/piccolo/query_clauses/distinct/albums.csv
new file mode 100644
index 000000000..c259f1154
--- /dev/null
+++ b/docs/src/piccolo/query_clauses/distinct/albums.csv
@@ -0,0 +1,5 @@
+id,band,title,release_date
+1,Pythonistas,Py album 2021,2021-12-01
+2,Pythonistas,Py album 2022,2022-12-01
+3,Rustaceans,Rusty album 2021,2021-12-01
+4,Rustaceans,Rusty album 2022,2022-12-01
diff --git a/docs/src/piccolo/query_clauses/first.rst b/docs/src/piccolo/query_clauses/first.rst
index 4a5d9442f..1c04150f1 100644
--- a/docs/src/piccolo/query_clauses/first.rst
+++ b/docs/src/piccolo/query_clauses/first.rst
@@ -12,14 +12,14 @@ Rather than returning a list of results, just the first result is returned.
 
 .. code-block:: python
 
-    >>> Band.select().first().run_sync()
+    >>> await Band.select().first()
     {'name': 'Pythonistas', 'manager': 1, 'popularity': 1000, 'id': 1}
 
 Likewise, with objects:
 
..
code-block:: python - >>> Band.objects().first().run_sync() - + >>> await Band.objects().first() + -If no match is found, then `None` is returned instead. +If no match is found, then ``None`` is returned instead. diff --git a/docs/src/piccolo/query_clauses/freeze.rst b/docs/src/piccolo/query_clauses/freeze.rst index 8c61b13ea..54f7c28e5 100644 --- a/docs/src/piccolo/query_clauses/freeze.rst +++ b/docs/src/piccolo/query_clauses/freeze.rst @@ -5,6 +5,9 @@ freeze You can use the ``freeze`` clause with any query type. +Source +------ + .. currentmodule:: piccolo.query.base .. automethod:: Query.freeze diff --git a/docs/src/piccolo/query_clauses/group_by.rst b/docs/src/piccolo/query_clauses/group_by.rst index a348120ce..516a4ac00 100644 --- a/docs/src/piccolo/query_clauses/group_by.rst +++ b/docs/src/piccolo/query_clauses/group_by.rst @@ -7,8 +7,10 @@ You can use ``group_by`` clauses with the following queries: * :ref:`Select` -It is used in combination with aggregate functions - ``Count`` is currently -supported. +It is used in combination with the :ref:`aggregate functions ` +- for example, ``Count``. + +------------------------------------------------------------------------------- Count ----- @@ -17,21 +19,23 @@ In the following query, we get a count of the number of bands per manager: .. code-block:: python - >>> from piccolo.query.methods.select import Count + >>> from piccolo.query.functions.aggregate import Count - >>> b = Band - >>> b.select( - >>> b.manager.name, - >>> Count(b.manager) - >>> ).group_by( - >>> b.manager - >>> ).run_sync() + >>> await Band.select( + ... Band.manager.name.as_alias('manager_name'), + ... Count(alias='band_count') + ... ).group_by( + ... Band.manager.name + ... ) [ - {"manager.name": "Graydon", "count": 1}, - {"manager.name": "Guido", "count": 1} + {"manager_name": "Graydon", "band_count": 1}, + {"manager_name": "Guido", "band_count": 1} ] -.. currentmodule:: piccolo.query.methods.select +------------------------------------------------------------------------------- + +Other aggregate functions +------------------------- -.. autoclass:: Count +These work the same as ``Count``. See :ref:`aggregate functions `. diff --git a/docs/src/piccolo/query_clauses/index.rst b/docs/src/piccolo/query_clauses/index.rst index 14f6ea96f..feac07b0f 100644 --- a/docs/src/piccolo/query_clauses/index.rst +++ b/docs/src/piccolo/query_clauses/index.rst @@ -6,16 +6,33 @@ Query Clauses Query clauses are used to modify a query by making it more specific, or by modifying the return values. + .. toctree:: - :maxdepth: 0 + :maxdepth: 1 + :caption: Essential ./first - ./distinct - ./group_by ./limit - ./offset ./order_by - ./output ./where + +.. toctree:: + :maxdepth: 1 + :caption: Advanced + ./batch + ./callback + ./distinct ./freeze + ./group_by + ./lock_rows + ./offset + ./on_conflict + ./output + ./returning + +.. toctree:: + :maxdepth: 1 + :caption: CockroachDB + + ./as_of diff --git a/docs/src/piccolo/query_clauses/limit.rst b/docs/src/piccolo/query_clauses/limit.rst index 090f1a6f7..8eee56b52 100644 --- a/docs/src/piccolo/query_clauses/limit.rst +++ b/docs/src/piccolo/query_clauses/limit.rst @@ -13,10 +13,10 @@ number you ask for. .. code-block:: python - Band.select().limit(2).run_sync() + await Band.select().limit(2) Likewise, with objects: .. 
code-block:: python
 
-    Band.objects().limit(2).run_sync()
+    await Band.objects().limit(2)
diff --git a/docs/src/piccolo/query_clauses/lock_rows.rst b/docs/src/piccolo/query_clauses/lock_rows.rst
new file mode 100644
index 000000000..17ac134d3
--- /dev/null
+++ b/docs/src/piccolo/query_clauses/lock_rows.rst
@@ -0,0 +1,132 @@
+.. _lock_rows:
+
+lock_rows
+=========
+
+You can use the ``lock_rows`` clause with the following queries:
+
+* :ref:`Objects`
+* :ref:`Select`
+
+It returns a query that locks rows until the end of the transaction,
+generating a ``SELECT ... FOR UPDATE`` SQL statement (or a similar statement,
+with one of the other lock strengths).
+
+.. note:: Postgres and CockroachDB only.
+
+-------------------------------------------------------------------------------
+
+Basic Usage
+-----------
+
+Basic usage without parameters:
+
+.. code-block:: python
+
+    await Band.select().where(Band.name == 'Pythonistas').lock_rows()
+
+Equivalent to:
+
+.. code-block:: sql
+
+    SELECT ... FOR UPDATE
+
+
+lock_strength
+-------------
+
+The parameter ``lock_strength`` controls the strength of the row lock when
+performing an operation in PostgreSQL. The value can be a predefined constant
+from the ``LockStrength`` enum or one of the following strings
+(case-insensitive):
+
+* ``UPDATE`` (default): Acquires an exclusive lock on the selected rows,
+  preventing other transactions from modifying or locking them until the
+  current transaction is complete.
+* ``NO KEY UPDATE`` (Postgres only): Similar to ``UPDATE``, but allows other
+  transactions to insert or delete rows that do not affect the primary key or
+  unique constraints.
+* ``KEY SHARE`` (Postgres only): Permits other transactions to acquire
+  key-share or share locks, allowing non-key modifications while preventing
+  updates or deletes.
+* ``SHARE``: Acquires a shared lock, allowing other transactions to read the
+  rows but not modify or lock them.
+
+You can specify a different lock strength:
+
+.. code-block:: python
+
+    await Band.select().where(Band.name == 'Pythonistas').lock_rows('SHARE')
+
+Which is equivalent to:
+
+.. code-block:: sql
+
+    SELECT ... FOR SHARE
+
+
+nowait
+------
+
+If another transaction has already acquired a lock on one or more selected
+rows, an exception will be raised instead of waiting for the other transaction
+to release the lock.
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.name == 'Pythonistas'
+    ).lock_rows('UPDATE', nowait=True)
+
+
+skip_locked
+-----------
+
+Ignore locked rows.
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.name == 'Pythonistas'
+    ).lock_rows('UPDATE', skip_locked=True)
+
+
+of
+--
+
+By default, if there are many tables in a query (e.g. when joining), all
+tables will be locked. Using ``of``, you can specify which tables should be
+locked.
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.manager.name == 'Guido'
+    ).lock_rows('UPDATE', of=(Band, ))
+
+-------------------------------------------------------------------------------
+
+Full example
+------------
+
+If we have this table:
+
+.. code-block:: python
+
+    class Concert(Table):
+        name = Varchar()
+        tickets_available = Integer()
+
+And we want to make sure that ``tickets_available`` never goes below 0, we can
+do the following:
+
..
code-block:: python + + async def book_tickets(ticket_count: int): + async with Concert._meta.db.transaction(): + concert = await Concert.objects().where( + Concert.name == "Awesome Concert" + ).first().lock_rows() + + if concert.tickets_available >= ticket_count: + await concert.update_self({ + Concert.tickets_available: Concert.tickets_available - ticket_count + }) + else: + raise ValueError("Not enough tickets are available!") + +This means that when multiple transactions are running at the same time, it +isn't possible to book more tickets than are available. + +.. note:: + + There is a `video tutorial on YouTube `__. + +------------------------------------------------------------------------------- + +Learn more +---------- + +* `Postgres docs `_ +* `CockroachDB docs `_ diff --git a/docs/src/piccolo/query_clauses/offset.rst b/docs/src/piccolo/query_clauses/offset.rst index c24b1a4a1..66915ad4f 100644 --- a/docs/src/piccolo/query_clauses/offset.rst +++ b/docs/src/piccolo/query_clauses/offset.rst @@ -15,12 +15,12 @@ otherwise the results returned could be different each time. .. code-block:: python - >>> Band.select(Band.name).offset(1).order_by(Band.name).run_sync() + >>> await Band.select(Band.name).offset(1).order_by(Band.name) [{'name': 'Pythonistas'}, {'name': 'Rustaceans'}] Likewise, with objects: .. code-block:: python - >>> Band.objects().offset(1).order_by(Band.name).run_sync() + >>> await Band.objects().offset(1).order_by(Band.name) [Band2, Band3] diff --git a/docs/src/piccolo/query_clauses/on_conflict.rst b/docs/src/piccolo/query_clauses/on_conflict.rst new file mode 100644 index 000000000..c3adfa51c --- /dev/null +++ b/docs/src/piccolo/query_clauses/on_conflict.rst @@ -0,0 +1,231 @@ +.. _on_conflict: + +on_conflict +=========== + +.. hint:: This is an advanced topic, and first time learners of Piccolo + can skip if they want. + +You can use the ``on_conflict`` clause with the following queries: + +* :ref:`Insert` + +Introduction +------------ + +When inserting rows into a table, if a unique constraint fails on one or more +of the rows, then the insertion fails. + +Using the ``on_conflict`` clause, we can instead tell the database to ignore +the error (using ``DO NOTHING``), or to update the row (using ``DO UPDATE``). + +This is sometimes called an **upsert** (update if it already exists else insert). + +Example data +------------ + +If we have the following table: + +.. code-block:: python + + class Band(Table): + name = Varchar(unique=True) + popularity = Integer() + +With this data: + +.. csv-table:: + :file: ./on_conflict/bands.csv + +Let's try inserting another row with the same ``name``, and we'll get an error: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ) + Unique constraint error! + +``DO NOTHING`` +-------------- + +To ignore the error: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO NOTHING" + ... ) + +If we fetch the data from the database, we'll see that it hasn't changed: + +.. code-block:: python + + >>> await Band.select().where(Band.name == "Pythonistas").first() + {'id': 1, 'name': 'Pythonistas', 'popularity': 1000} + + +``DO UPDATE`` +------------- + +Instead, if we want to update the ``popularity``: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... target=Band.name, + ... values=[Band.popularity] + ... 
)
 
+If we fetch the data from the database, we'll see that it was updated:
+
+.. code-block:: python
+
+    >>> await Band.select().where(Band.name == "Pythonistas").first()
+    {'id': 1, 'name': 'Pythonistas', 'popularity': 1200}
+
+``target``
+----------
+
+Using the ``target`` argument, we can specify which constraint we're concerned
+with. By specifying ``target=Band.name`` we're only concerned with the unique
+constraint for the ``name`` column. If you omit the ``target`` argument with
+the ``DO NOTHING`` action, then it applies to all constraints on the table.
+With the ``DO UPDATE`` action, ``target`` is mandatory.
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO NOTHING",
+    ...     target=Band.name
+    ... )
+
+If you want to target a composite unique constraint, you can do so by passing
+in a tuple of columns:
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO NOTHING",
+    ...     target=(Band.name, Band.popularity)
+    ... )
+
+You can also specify the name of a constraint using a string:
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO NOTHING",
+    ...     target='some_constraint'
+    ... )
+
+``values``
+----------
+
+This lets us specify which values to update when a conflict occurs.
+
+By specifying a :class:`Column `, this means that
+the new value for that column will be used:
+
+.. code-block:: python
+    :emphasize-lines: 6
+
+    # The new popularity will be 1200.
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO UPDATE",
+    ...     values=[Band.popularity]
+    ... )
+
+Instead, we can specify a custom value using a tuple:
+
+.. code-block:: python
+    :emphasize-lines: 6
+
+    # The new popularity will be 1111.
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO UPDATE",
+    ...     values=[(Band.popularity, 1111)]
+    ... )
+
+If we want to update all of the values, we can use :meth:`all_columns`.
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    >>> await Band.insert(
+    ...     Band(id=1, name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO UPDATE",
+    ...     values=Band.all_columns()
+    ... )
+
+``where``
+---------
+
+This can be used with ``DO UPDATE``. It gives us more control over whether the
+update should be made:
+
+.. code-block:: python
+    :emphasize-lines: 6
+
+    >>> await Band.insert(
+    ...     Band(id=1, name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO UPDATE",
+    ...     values=[Band.popularity],
+    ...     where=Band.popularity < 1000
+    ... )
+
+Multiple ``on_conflict`` clauses
+--------------------------------
+
+SQLite allows you to specify multiple ``ON CONFLICT`` clauses, but Postgres and
+Cockroach don't.
+
+.. code-block:: python
+
+    >>> await Band.insert(
+    ...     Band(name="Pythonistas", popularity=1200)
+    ... ).on_conflict(
+    ...     action="DO UPDATE",
+    ...     ...
+    ... ).on_conflict(
+    ...     action="DO NOTHING",
+    ...     ...
+    ... )
+
+Learn more
+----------
+
+* `Postgres docs `_
+* `Cockroach docs `_
+* `SQLite docs `_
+
+Source
+------
+
+.. currentmodule:: piccolo.query.methods.insert
+
+.. automethod:: Insert.on_conflict
+
..
autoclass:: OnConflictAction + :members: + :undoc-members: diff --git a/docs/src/piccolo/query_clauses/on_conflict/bands.csv b/docs/src/piccolo/query_clauses/on_conflict/bands.csv new file mode 100644 index 000000000..d796928a1 --- /dev/null +++ b/docs/src/piccolo/query_clauses/on_conflict/bands.csv @@ -0,0 +1,2 @@ +id,name,popularity +1,Pythonistas,1000 diff --git a/docs/src/piccolo/query_clauses/order_by.rst b/docs/src/piccolo/query_clauses/order_by.rst index 17a88b7c3..eb109d2fc 100644 --- a/docs/src/piccolo/query_clauses/order_by.rst +++ b/docs/src/piccolo/query_clauses/order_by.rst @@ -12,27 +12,76 @@ To order the results by a certain column (ascending): .. code-block:: python - b = Band - b.select().order_by( - b.name - ).run_sync() + await Band.select().order_by( + Band.name + ) To order by descending: .. code-block:: python - b = Band - b.select().order_by( - b.name, + await Band.select().order_by( + Band.name, ascending=False - ).run_sync() + ) + +You can specify the column name as a string if you prefer: + +.. code-block:: python + + await Band.select().order_by( + 'name' + ) You can order by multiple columns, and even use joins: .. code-block:: python - b = Band - b.select().order_by( - b.name, - b.manager.name - ).run_sync() + await Band.select().order_by( + Band.name, + Band.manager.name + ) + +------------------------------------------------------------------------------- + +Advanced +-------- + +Ascending and descending +~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to order by multiple columns, with some ascending, and some +descending, then you can do so using multiple ``order_by`` statements: + + +.. code-block:: python + + await Band.select().order_by( + Band.name, + ).order_by( + Band.popularity, + ascending=False + ) + +``OrderByRaw`` +~~~~~~~~~~~~~~ + +SQL's ``ORDER BY`` clause is surprisingly rich in functionality, and there may +be situations where you want to specify the ``ORDER BY`` explicitly using SQL. +To do this use ``OrderByRaw``. + +In the example below, we are ordering the results randomly: + +.. code-block:: python + + from piccolo.query import OrderByRaw + + await Band.select(Band.name).order_by( + OrderByRaw('random()'), + ) + +The above is equivalent to the following SQL: + +.. code-block:: sql + + SELECT "band"."name" FROM band ORDER BY random() ASC diff --git a/docs/src/piccolo/query_clauses/output.rst b/docs/src/piccolo/query_clauses/output.rst index 73d7a484e..f9161b1cf 100644 --- a/docs/src/piccolo/query_clauses/output.rst +++ b/docs/src/piccolo/query_clauses/output.rst @@ -20,8 +20,8 @@ To return the data as a JSON string: .. code-block:: python - >>> Band.select().output(as_json=True).run_sync() - '[{"name":"Pythonistas","manager":1,"popularity":1000,"id":1},{"name":"Rustaceans","manager":2,"popularity":500,"id":2}]' + >>> await Band.select(Band.name).output(as_json=True) + '[{"name":"Pythonistas"}]' Piccolo can use `orjson `_ for JSON serialisation, which is blazing fast, and can handle most Python types, including dates, @@ -36,25 +36,37 @@ If you're just querying a single column from a database table, you can use .. code-block:: python - >>> Band.select(Band.id).output(as_list=True).run_sync() + >>> await Band.select(Band.id).output(as_list=True) [1, 2] +nested +~~~~~~ + +Output any data from related tables in nested dictionaries. + +.. 
code-block:: python
 
+    >>> await Band.select(Band.name, Band.manager.name).first().output(nested=True)
+    {'name': 'Pythonistas', 'manager': {'name': 'Guido'}}
+
 -------------------------------------------------------------------------------
 
 Select and Objects queries
 --------------------------
 
+.. _load_json:
+
 load_json
 ~~~~~~~~~
 
-If querying JSON or JSONB columns, you can tell Piccolo to deserialise the JSON
-values automatically.
+If querying :class:`JSON ` or :class:`JSONB `
+columns, you can tell Piccolo to deserialise the JSON values automatically.
 
 .. code-block:: python
 
-    >>> RecordingStudio.select().output(load_json=True).run_sync()
+    >>> await RecordingStudio.select().output(load_json=True)
     [{'id': 1, 'name': 'Abbey Road', 'facilities': {'restaurant': True, 'mixing_desk': True}}]
 
-    >>> studio = RecordingStudio.objects().first().output(load_json=True).run_sync()
+    >>> studio = await RecordingStudio.objects().first().output(load_json=True)
     >>> studio.facilities
     {'restaurant': True, 'mixing_desk': True}
diff --git a/docs/src/piccolo/query_clauses/returning.rst b/docs/src/piccolo/query_clauses/returning.rst
new file mode 100644
index 000000000..6f613e049
--- /dev/null
+++ b/docs/src/piccolo/query_clauses/returning.rst
@@ -0,0 +1,52 @@
+.. _returning:
+
+returning
+=========
+
+You can use the ``returning`` clause with the following queries:
+
+* :ref:`Insert`
+* :ref:`Update`
+* :ref:`Delete`
+
+By default, an update query returns an empty list, but using the ``returning``
+clause you can retrieve values from the updated rows.
+
+.. code-block:: python
+
+    >>> await Band.update({
+    ...     Band.name: 'Pythonistas Tribute Band'
+    ... }).where(
+    ...     Band.name == 'Pythonistas'
+    ... ).returning(Band.id, Band.name)
+    [{'id': 1, 'name': 'Pythonistas Tribute Band'}]
+
+Similarly, for an insert query - we can retrieve some of the values from the
+inserted rows:
+
+.. code-block:: python
+
+    >>> await Manager.insert(
+    ...     Manager(name="Maz"),
+    ...     Manager(name="Graydon")
+    ... ).returning(Manager.id, Manager.name)
+
+    [{'id': 1, 'name': 'Maz'}, {'id': 2, 'name': 'Graydon'}]
+
+As another example, let's use delete and return the full row(s):
+
+.. code-block:: python
+
+    >>> await Band.delete().where(
+    ...     Band.name == "Pythonistas"
+    ... ).returning(*Band.all_columns())
+
+    [{'id': 1, 'name': 'Pythonistas', 'manager': 1, 'popularity': 1000}]
+
+By counting the number of elements of the returned list, you can find out
+how many rows were affected or processed by the operation.
+
+.. warning:: This works for all versions of Postgres, but only
+   `SQLite 3.35.0 `_ and above
+   support the returning clause. See the :ref:`docs ` on
+   how to check your SQLite version.
diff --git a/docs/src/piccolo/query_clauses/where.rst b/docs/src/piccolo/query_clauses/where.rst
index 4137c2cd2..bd88fc633 100644
--- a/docs/src/piccolo/query_clauses/where.rst
+++ b/docs/src/piccolo/query_clauses/where.rst
@@ -5,6 +5,7 @@ where
 
 You can use ``where`` clauses with the following queries:
 
+* :ref:`Count`
 * :ref:`Delete`
 * :ref:`Exists`
 * :ref:`Objects`
@@ -13,28 +14,29 @@ You can use ``where`` clauses with the following queries:
 
 It allows powerful filtering of your data.
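+
+For example, a quick sketch of a ``count`` query with a ``where`` clause,
+assuming it chains in the same way as the other queries (the number returned
+depends on your data):
+
+.. code-block:: python
+
+    >>> await Band.count().where(Band.popularity > 100)
+    2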
 
+-------------------------------------------------------------------------------
+
 Equal / Not Equal
 -----------------
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name == 'Pythonistas'
-    ).run_sync()
+    await Band.select().where(
+        Band.name == 'Pythonistas'
+    )
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name != 'Rustaceans'
-    ).run_sync()
+    await Band.select().where(
+        Band.name != 'Rustaceans'
+    )
 
-.. hint:: With ``Boolean`` columns, some linters will complain if you write
-   ``SomeTable.some_column == True`` (because it's more Pythonic to do
-   ``is True``). To work around this, you can do
-   ``SomeTable.some_column.eq(True)``. Likewise, with ``!=`` you can use
-   ``SomeTable.some_column.ne(True)``
+.. hint:: With :class:`Boolean ` columns,
+   some linters will complain if you write
+   ``SomeTable.some_column == True`` (because it's more Pythonic to do
+   ``is True``). To work around this, you can do
+   ``SomeTable.some_column.eq(True)``. Likewise, with ``!=`` you can use
+   ``SomeTable.some_column.ne(True)``
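+
+For example, a sketch using a hypothetical ``Band.has_drummer`` column:
+
+.. code-block:: python
+
+    # Equivalent to Band.has_drummer == True, but linter-friendly:
+    await Band.select().where(
+        Band.has_drummer.eq(True)
+    )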
 
 -------------------------------------------------------------------------------
 
@@ -45,105 +47,144 @@ You can use the ``<, >, <=, >=`` operators, which work as you expect.
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.popularity >= 100
-    ).run_sync()
+    await Band.select().where(
+        Band.popularity >= 100
+    )
 
 -------------------------------------------------------------------------------
 
-like / ilike
--------------
+``like`` / ``ilike``
+--------------------
 
 The percentage operator is required to designate where the match should occur.
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name.like('Py%')  # Matches the start of the string
-    ).run_sync()
+    await Band.select().where(
+        Band.name.like('Py%')  # Matches the start of the string
+    )
+
+    await Band.select().where(
+        Band.name.like('%istas')  # Matches the end of the string
+    )
 
-    b.select().where(
-        b.name.like('%istas')  # Matches the end of the string
-    ).run_sync()
+    await Band.select().where(
+        Band.name.like('%is%')  # Matches anywhere in the string
+    )
 
-    b.select().where(
-        b.name.like('%is%')  # Matches anywhere in string
-    ).run_sync()
+    await Band.select().where(
+        Band.name.like('Pythonistas')  # Matches the entire string
+    )
 
-``ilike`` is identical, except it's case insensitive.
+``ilike`` is identical, except it's Postgres specific and case insensitive.
 
 -------------------------------------------------------------------------------
 
-not_like
---------
+``not_like``
+------------
 
 Usage is the same as ``like`` except it excludes matching rows.
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name.not_like('Py%')
-    ).run_sync()
+    await Band.select().where(
+        Band.name.not_like('Py%')
+    )
 
 -------------------------------------------------------------------------------
 
-is_in / not_in
---------------
+``is_in`` / ``not_in``
+----------------------
+
+You can get all rows with a value contained in the list:
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.name.is_in(['Pythonistas', 'Rustaceans'])
+    )
+
+And all rows with a value not contained in the list:
+
+.. code-block:: python
+
+    await Band.select().where(
+        Band.name.not_in(['Terrible Band', 'Awful Band'])
+    )
+
+You can also pass a subquery into the ``is_in`` clause:
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name.is_in(['Pythonistas'])
-    ).run_sync()
+    await Band.select().where(
+        Band.id.is_in(
+            Concert.select(Concert.band_1).where(
+                Concert.starts >= datetime.datetime(year=2025, month=1, day=1)
+            )
+        )
+    )
+
+.. hint::
+   In SQL there are often several ways of solving the same problem. You
+   can also solve the above using :meth:`join_on `.
+
+   .. code-block:: python
+
+       >>> await Band.select().where(
+       ...     Band.id.join_on(Concert.band_1).starts >= datetime.datetime(
+       ...         year=2025, month=1, day=1
+       ...     )
+       ... )
+
+   Use whichever you prefer, and whichever suits the situation best.
+
+Subqueries can also be passed into the ``not_in`` clause:
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        b.name.not_in(['Rustaceans'])
-    ).run_sync()
+    await Band.select().where(
+        Band.id.not_in(
+            Concert.select(Concert.band_1).where(
+                Concert.starts >= datetime.datetime(year=2025, month=1, day=1)
+            )
+        )
+    )
+
 -------------------------------------------------------------------------------
 
-is_null / is_not_null
----------------------
+``is_null`` / ``is_not_null``
+-----------------------------
 
 These queries work, but some linters will complain about doing a comparison
-with None:
+with ``None``:
 
 .. code-block:: python
 
-    b = Band
     # Fetch all bands with a manager
-    b.select().where(
-        b.manager != None
-    ).run_sync()
+    await Band.select().where(
+        Band.manager != None
+    )
 
     # Fetch all bands without a manager
-    b.select().where(
-        b.manager == None
-    ).run_sync()
+    await Band.select().where(
+        Band.manager == None
+    )
 
-To avoid the linter errors, you can use `is_null` and `is_not_null` instead.
+To avoid the linter errors, you can use ``is_null`` and ``is_not_null``
+instead.
 
 .. code-block:: python
 
-    b = Band
     # Fetch all bands with a manager
-    b.select().where(
-        b.manager.is_not_null()
-    ).run_sync()
+    await Band.select().where(
+        Band.manager.is_not_null()
+    )
 
     # Fetch all bands without a manager
-    b.select().where(
-        b.manager.is_null()
-    ).run_sync()
+    await Band.select().where(
+        Band.manager.is_null()
+    )
 
 -------------------------------------------------------------------------------
 
@@ -154,14 +195,13 @@ You can make complex ``where`` queries using ``&`` for AND, and ``|`` for OR.
 
 .. code-block:: python
 
-    b = Band
-    b.select().where(
-        (b.popularity >= 100) & (b.popularity < 1000)
-    ).run_sync()
+    await Band.select().where(
+        (Band.popularity >= 100) & (Band.popularity < 1000)
+    )
 
-    b.select().where(
-        (b.popularity >= 100) | (b.name == 'Pythonistas')
-    ).run_sync()
+    await Band.select().where(
+        (Band.popularity >= 100) | (Band.name == 'Pythonistas')
+    )
 
 You can make really complex ``where`` clauses if you so choose - just be
 careful to include brackets in the correct place.
 
     ((b.popularity >= 100) & (b.manager.name == 'Guido')) | (b.popularity > 1000)
 
+Multiple ``where`` clauses
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 Using multiple ``where`` clauses is equivalent to an AND.
 
 .. code-block:: python
 
-    b = Band
     # These are equivalent:
-    b.select().where(
-        (b.popularity >= 100) & (b.popularity < 1000)
-    ).run_sync()
+    await Band.select().where(
+        (Band.popularity >= 100) & (Band.popularity < 1000)
+    )
 
-    b.select().where(
-        b.popularity >= 100
+    await Band.select().where(
+        Band.popularity >= 100
     ).where(
-        b.popularity < 1000
-    ).run_sync()
+        Band.popularity < 1000
+    )
 
-Using And / Or directly
-~~~~~~~~~~~~~~~~~~~~~~~
+Also, passing multiple arguments to a ``where`` clause is equivalent to an
+AND.
+
+.. code-block:: python
+
+    # These are equivalent:
+    await Band.select().where(
+        (Band.popularity >= 100) & (Band.popularity < 1000)
+    )
+
+    await Band.select().where(
+        Band.popularity >= 100, Band.popularity < 1000
+    )
+
+Using ``And`` / ``Or`` directly
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Rather than using the ``|`` and ``&`` characters, you can use the ``And`` and
 ``Or`` classes, which are what's used under the hood.
@@ -197,19 +251,17 @@ Rather than using the ``|`` and ``&`` characters, you can use the ``And`` and from piccolo.columns.combination import And, Or - b = Band - - b.select().where( + await Band.select().where( Or( - And(b.popularity >= 100, b.popularity < 1000), - b.name == 'Pythonistas' + And(Band.popularity >= 100, Band.popularity < 1000), + Band.name == 'Pythonistas' ) - ).run_sync() + ) ------------------------------------------------------------------------------- -WhereRaw --------- +``WhereRaw`` +------------ In certain situations you may want to have raw SQL in your where clause. @@ -217,9 +269,9 @@ In certain situations you may want to have raw SQL in your where clause. from piccolo.columns.combination import WhereRaw - Band.select().where( + await Band.select().where( WhereRaw("name = 'Pythonistas'") - ).run_sync() + ) It's important to parameterise your SQL statements if the values come from an untrusted source, otherwise it could lead to a SQL injection attack. @@ -230,9 +282,9 @@ untrusted source, otherwise it could lead to a SQL injection attack. value = "Could be dangerous" - Band.select().where( + await Band.select().where( WhereRaw("name = {}", value) - ).run_sync() + ) ``WhereRaw`` can be combined into complex queries, just as you'd expect: @@ -240,7 +292,37 @@ untrusted source, otherwise it could lead to a SQL injection attack. from piccolo.columns.combination import WhereRaw - b = Band - b.select().where( - WhereRaw("name = 'Pythonistas'") | (b.popularity > 1000) - ).run_sync() + await Band.select().where( + WhereRaw("name = 'Pythonistas'") | (Band.popularity > 1000) + ) + +------------------------------------------------------------------------------- + +Joins +----- + +The ``where`` clause has full support for joins. For example: + +.. code-block:: python + + >>> await Band.select(Band.name).where(Band.manager.name == 'Guido') + [{'name': 'Pythonistas'}] + +------------------------------------------------------------------------------- + +Conditional ``where`` clauses +----------------------------- + +You can add ``where`` clauses conditionally (e.g. based on user input): + +.. code-block:: python + + async def get_band_names(only_popular_bands: bool) -> list[str]: + query = Band.select(Band.name).output(as_list=True) + + if only_popular_bands: + query = query.where(Band.popularity >= 1000) + + return await query + +.. hint:: This works with all clauses, not just ``where`` clauses. diff --git a/docs/src/piccolo/query_types/alter.rst b/docs/src/piccolo/query_types/alter.rst index 4d6f69cb2..c7fe0fcbc 100644 --- a/docs/src/piccolo/query_types/alter.rst +++ b/docs/src/piccolo/query_types/alter.rst @@ -7,6 +7,7 @@ This is used to modify an existing table. .. hint:: You can use migrations instead of manually altering the schema - see :ref:`Migrations`. +------------------------------------------------------------------------------- add_column ---------- @@ -15,8 +16,9 @@ Used to add a column to an existing table. .. code-block:: python - Band.alter().add_column('members', Integer()).run_sync() + await Band.alter().add_column('members', Integer()) +------------------------------------------------------------------------------- drop_column ----------- @@ -25,8 +27,9 @@ Used to drop an existing column. .. 
code-block:: python

-    Band.alter().drop_column('popularity').run_sync()
+    await Band.alter().drop_column('popularity')

+-------------------------------------------------------------------------------

 drop_table
 ----------

@@ -35,8 +38,27 @@ Used to drop the table - use with caution!

 .. code-block:: python

-    Band.alter().drop_table().run_sync()
+    await Band.alter().drop_table()

+drop_db_tables / drop_db_tables_sync
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you have several tables which you want to drop, you can use
+:func:`drop_db_tables ` or
+:func:`drop_db_tables_sync `. The tables
+will be dropped in the correct order based on their foreign keys.
+
+.. code-block:: python
+
+    # async version
+    >>> from piccolo.table import drop_db_tables
+    >>> await drop_db_tables(Band, Manager)
+
+    # sync version
+    >>> from piccolo.table import drop_db_tables_sync
+    >>> drop_db_tables_sync(Band, Manager)
+
+-------------------------------------------------------------------------------

 rename_column
 -------------

@@ -45,8 +67,9 @@ Used to rename an existing column.

 .. code-block:: python

-    Band.alter().rename_column(Band.popularity, 'rating').run_sync()
+    await Band.alter().rename_column(Band.popularity, 'rating')

+-------------------------------------------------------------------------------

 set_null
 --------

@@ -56,11 +79,35 @@ Set whether a column is nullable or not.

 .. code-block:: python

     # To make a column nullable:
-    Band.alter().set_null(Band.name, True).run_sync()
+    await Band.alter().set_null(Band.name, True)

     # To stop a column being nullable:
-    Band.alter().set_null(Band.name, False).run_sync()
+    await Band.alter().set_null(Band.name, False)
+
+-------------------------------------------------------------------------------
+
+set_schema
+----------
+
+Used to change the `schema `_
+which a table belongs to.
+
+.. code-block:: python
+
+    await Band.alter().set_schema('schema_1')
+
+Schemas are a way of organising the tables within a database. Only Postgres and
+Cockroach support schemas. :ref:`Learn more here `.
+
+After changing a table's schema, you need to update your ``Table`` accordingly,
+otherwise subsequent queries will fail, as they'll be trying to find the table
+in the old schema.
+
+.. code-block:: python
+
+    Band._meta.schema = 'schema_1'

+-------------------------------------------------------------------------------

 set_unique
 ----------

@@ -70,7 +117,7 @@ Used to change whether a column is unique or not.

 .. code-block:: python

     # To make a column unique:
-    Band.alter().set_unique(Band.name, True).run_sync()
+    await Band.alter().set_unique(Band.name, True)

     # To stop a column being unique:
-    Band.alter().set_unique(Band.name, False).run_sync()
+    await Band.alter().set_unique(Band.name, False)
diff --git a/docs/src/piccolo/query_types/count.rst b/docs/src/piccolo/query_types/count.rst
new file mode 100644
index 000000000..794d125bc
--- /dev/null
+++ b/docs/src/piccolo/query_types/count.rst
@@ -0,0 +1,140 @@
+.. _Count:
+
+Count
+=====
+
+The ``count`` query makes it really easy to retrieve the number of rows in a
+table:
+
+.. code-block:: python
+
+    >>> await Band.count()
+    3
+
+It's equivalent to this ``select`` query:
+
+.. code-block:: python
+
+    from piccolo.query.functions.aggregate import Count
+
+    >>> response = await Band.select(Count())
+    >>> response[0]['count']
+    3
+
+As you can see, the ``count`` query is more convenient.
+
+Non-null columns
+----------------
+
+If you want to retrieve the number of rows where a given column isn't null, we
+can do so as follows:
+
+..
code-block:: python + + await Band.count(column=Band.name) + + # Or simply: + await Band.count(Band.name) + +Note, this is equivalent to: + +.. code-block:: python + + await Band.count().where(Band.name.is_not_null()) + +Example +~~~~~~~ + +If we have the following database table: + +.. code-block:: python + + class Band(Table): + name = Varchar() + popularity = Integer(null=True) + +With the following data: + +.. table:: + :widths: auto + + ============ ========== + name popularity + ============ ========== + Pythonistas 1000 + Rustaceans 800 + C-Sharps ``null`` + ============ ========== + +Then we get the following results: + +.. code-block:: python + + >>> await Band.count() + 3 + + >>> await Band.count(Band.popularity) + 2 + +distinct +-------- + +We can count the number of distinct (i.e. unique) rows. + +.. code-block:: python + + await Band.count(distinct=[Band.name]) + + # This also works - use whichever you prefer: + await Band.count().distinct([Band.name]) + +With the following data: + +.. table:: + :widths: auto + + ============ ========== + name popularity + ============ ========== + Pythonistas 1000 + Pythonistas 1000 + Pythonistas 800 + Rustaceans 800 + ============ ========== + +Note how we have duplicate band names. + +.. hint:: + This is bad database design as we should add a unique constraint to + prevent this, but go with it for this example! + +Let's compare queries with and without ``distinct``: + +.. code-block:: python + + >>> await Band.count() + 4 + + >>> await Band.count(distinct=[Band.name]) + 2 + +We can specify multiple columns: + +.. code-block:: python + + >>> await Band.count(distinct=[Band.name, Band.popularity]) + 3 + +In the above example, this means we count rows where the combination of +``name`` and ``popularity`` is unique. + +So ``('Pythonistas', 1000)`` is a distinct value from ``('Pythonistas', 800)``, +because even though the ``name`` is the same, the ``popularity`` is different. + +Clauses +------- + +where +~~~~~ + +See :ref:`where`. diff --git a/docs/src/piccolo/query_types/create_table.rst b/docs/src/piccolo/query_types/create_table.rst index 726b90329..569462d3c 100644 --- a/docs/src/piccolo/query_types/create_table.rst +++ b/docs/src/piccolo/query_types/create_table.rst @@ -9,7 +9,7 @@ This creates the table and columns in the database. .. code-block:: python - >>> Band.create_table().run_sync() + >>> await Band.create_table() [] @@ -17,5 +17,23 @@ To prevent an error from being raised if the table already exists: .. code-block:: python - >>> Band.create_table(if_not_exists=True).run_sync() + >>> await Band.create_table(if_not_exists=True) [] + +create_db_tables / create_db_tables_sync +---------------------------------------- + +You can create multiple tables at once. + +This function will automatically sort tables based on their foreign keys so +they're created in the right order: + +.. code-block:: python + + # async version + >>> from piccolo.table import create_db_tables + >>> await create_db_tables(Band, Manager, if_not_exists=True) + + # sync version + >>> from piccolo.table import create_db_tables_sync + >>> create_db_tables_sync(Band, Manager, if_not_exists=True) diff --git a/docs/src/piccolo/query_types/delete.rst b/docs/src/piccolo/query_types/delete.rst index 80cb42fb4..54401f9f3 100644 --- a/docs/src/piccolo/query_types/delete.rst +++ b/docs/src/piccolo/query_types/delete.rst @@ -7,9 +7,11 @@ This deletes any matching rows from the table. .. 
code-block:: python

-    >>> Band.delete().where(Band.name == 'Rustaceans').run_sync()
+    >>> await Band.delete().where(Band.name == 'Rustaceans')
     []

+-------------------------------------------------------------------------------
+
 force
 -----

@@ -19,16 +21,23 @@ the data from a table.

 .. code-block:: python

-    >>> Band.delete().run_sync()
+    >>> await Band.delete()
     Raises: DeletionError

     # Works fine:
-    >>> Band.delete(force=True).run_sync()
+    >>> await Band.delete(force=True)
     []

+-------------------------------------------------------------------------------
+
 Query clauses
 -------------

+returning
+~~~~~~~~~
+
+See :ref:`Returning`.
+
 where
 ~~~~~

diff --git a/docs/src/piccolo/query_types/django_comparison.rst b/docs/src/piccolo/query_types/django_comparison.rst
index be30c80be..d16c03d57 100644
--- a/docs/src/piccolo/query_types/django_comparison.rst
+++ b/docs/src/piccolo/query_types/django_comparison.rst
@@ -6,9 +6,51 @@ Django Comparison

 Here are some common queries, showing how they're done in Django vs Piccolo.
 All of the Piccolo examples can also be run :ref:`asynchronously`.

+-------------------------------------------------------------------------------
+
 Queries
 -------

+get
+~~~
+
+They are very similar, except Django raises an ``ObjectDoesNotExist`` exception
+if no match is found, whilst Piccolo returns ``None``.
+
+.. code-block:: python
+
+    # Django
+    >>> Band.objects.get(name="Pythonistas")
+    <Band: Band object (1)>
+    >>> Band.objects.get(name="DOESN'T EXIST")  # ObjectDoesNotExist!
+
+    # Piccolo
+    >>> Band.objects().get(Band.name == 'Pythonistas').run_sync()
+    <Band: 1>
+    >>> Band.objects().get(Band.name == "DOESN'T EXIST").run_sync()
+    None
+
+get_or_create
+~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    # Django
+    band, created = Band.objects.get_or_create(name="Pythonistas")
+    >>> band
+    <Band: Band object (1)>
+    >>> created
+    True
+
+    # Piccolo
+    >>> band = Band.objects().get_or_create(Band.name == 'Pythonistas').run_sync()
+    >>> band
+    <Band: 1>
+    >>> band._was_created
+    True
+
 create
 ~~~~~~

@@ -39,7 +81,7 @@ update

     >>> band.save()

     # Piccolo
-    >>> band = Band.objects().where(Band.name == 'Pythonistas').first().run_sync()
+    >>> band = Band.objects().get(Band.name == 'Pythonistas').run_sync()
     >>> band
     <Band: 1>
     >>> band.name = "Amazing Band"

@@ -57,7 +99,7 @@ Individual rows:

     >>> band.delete()

     # Piccolo
-    >>> band = Band.objects().where(Band.name == 'Pythonistas').first().run_sync()
+    >>> band = Band.objects().get(Band.name == 'Pythonistas').run_sync()
     >>> band.remove().run_sync()

 In bulk:

@@ -68,7 +110,7 @@ In bulk:

     >>> Band.objects.filter(popularity__lt=1000).delete()

     # Piccolo
-    >>> Band.delete().where(Band.popularity < 1000).delete().run_sync()
+    >>> Band.delete().where(Band.popularity < 1000).run_sync()

 filter
 ~~~~~~

@@ -108,10 +150,57 @@ With ``flat=True``:

     >>> Band.select(Band.name).output(as_list=True).run_sync()
     ['Pythonistas', 'Rustaceans']

+select_related
+~~~~~~~~~~~~~~
+
+Django has an optimisation called ``select_related`` which reduces the number
+of SQL queries required when accessing related objects.
+
+.. code-block:: python
+
+    # Django
+    band = Band.objects.get(name='Pythonistas')
+    >>> band.manager  # This triggers another db query
+    <Manager: Manager object (1)>
+
+    # Django, with select_related
+    band = Band.objects.select_related('manager').get(name='Pythonistas')
+    >>> band.manager  # Manager is pre-cached, so there's no extra db query
+    <Manager: Manager object (1)>
+
+Piccolo has something similar:
+
+..
code-block:: python
+
+    # Piccolo
+    band = Band.objects(Band.manager).get(Band.name == 'Pythonistas').run_sync()
+    >>> band.manager
+    <Manager: 1>
+
+-------------------------------------------------------------------------------
+
+Schema
+------
+
+OneToOneField
+~~~~~~~~~~~~~
+
+To do this in Piccolo, use a ``ForeignKey`` with a unique constraint - see
+:ref:`One to One`.
+
 -------------------------------------------------------------------------------

-Database Settings
+Database settings
 -----------------

 In Django you configure your database in ``settings.py``.

 With Piccolo, you define an ``Engine`` in ``piccolo_conf.py``. See
 :ref:`Engines`.
+
+-------------------------------------------------------------------------------
+
+Creating a new project
+----------------------
+
+With Django you use ``django-admin startproject mysite``.
+
+In Piccolo you use ``piccolo asgi new`` (see :ref:`ASGICommand`).
diff --git a/docs/src/piccolo/query_types/exists.rst b/docs/src/piccolo/query_types/exists.rst
index 8ef5ab9f2..982ba54f5 100644
--- a/docs/src/piccolo/query_types/exists.rst
+++ b/docs/src/piccolo/query_types/exists.rst
@@ -7,9 +7,11 @@ This checks whether any rows exist which match the criteria.

 .. code-block:: python

-    >>> Band.exists().where(Band.name == 'Pythonistas').run_sync()
+    >>> await Band.exists().where(Band.name == 'Pythonistas')
     True

+-------------------------------------------------------------------------------
+
 Query clauses
 -------------

diff --git a/docs/src/piccolo/query_types/index.rst b/docs/src/piccolo/query_types/index.rst
index ca650dec0..19a012aa2 100644
--- a/docs/src/piccolo/query_types/index.rst
+++ b/docs/src/piccolo/query_types/index.rst
@@ -4,7 +4,7 @@ Query Types

 There are many different queries you can perform using Piccolo.

 The main ways to query data are with :ref:`Select`, which returns data as
-dictionaries, and :ref:`Objects` , which returns data as class instances, like a
+dictionaries, and :ref:`Objects`, which returns data as class instances, like a
 typical ORM.

 .. toctree::
@@ -12,14 +12,27 @@ typical ORM.

     ./select
     ./objects
-    ./alter
+    ./count
+    ./alter
     ./create_table
     ./delete
     ./exists
     ./insert
     ./raw
     ./update
+
+-------------------------------------------------------------------------------
+
+Features
+--------
+
+.. toctree::
+    :maxdepth: 1
+
+    ./transactions
+    ./joins
+
+-------------------------------------------------------------------------------

 Comparisons
 -----------

diff --git a/docs/src/piccolo/query_types/insert.rst b/docs/src/piccolo/query_types/insert.rst
index 8af1bf69c..6648d6d0b 100644
--- a/docs/src/piccolo/query_types/insert.rst
+++ b/docs/src/piccolo/query_types/insert.rst
@@ -3,31 +3,47 @@ Insert
 ======

-This is used to insert rows into the table.
+This is used to bulk insert rows into the table:

..
code-block:: python +------------------------------------------------------------------------------- - Band.insert().add( - Band(name="Darts") - ).add( - Band(name="Gophers") - ).run_sync() +Query clauses +------------- + +on_conflict +~~~~~~~~~~~ + +See :ref:`On_Conflict`. + + +returning +~~~~~~~~~ + +See :ref:`Returning`. diff --git a/docs/src/piccolo/query_types/joins.rst b/docs/src/piccolo/query_types/joins.rst new file mode 100644 index 000000000..3d89c6da2 --- /dev/null +++ b/docs/src/piccolo/query_types/joins.rst @@ -0,0 +1,73 @@ +Joins +===== + +Joins are handled automatically by Piccolo. They work everywhere you'd expect +(select queries, where clauses, etc.). + +A `fluent interface `_ is used, +which lets you traverse foreign keys. + +Here's an example of a select query which uses joins (using the +:ref:`example schema `): + +.. code-block:: python + + # This gets the band's name, and the manager's name by joining to the + # manager table: + >>> await Band.select(Band.name, Band.manager.name) + +And a ``where`` clause which uses joins: + +.. code-block:: python + + # This automatically joins with the manager table to perform the where + # clause. It only returns the columns from the band table though by default. + >>> await Band.select().where(Band.manager.name == 'Guido') + +Left joins are used. + +Improved static typing +---------------------- + +You can optionally modify the above queries slightly for powerful static typing +support from tools like Mypy and Pylance: + +.. code-block:: python + + await Band.select(Band.name, Band.manager._.name) + +Notice how we use ``._.`` instead of ``.`` after each foreign key. An easy way +to remember this is ``._.`` looks a bit like a connector in a diagram. + +Static type checkers now know that we're referencing the ``name`` column on the +``Manager`` table, which has many advantages: + +* Autocompletion of column names. +* Easier code navigation (command + click on column names to navigate to the + column definition). +* Most importantly, the detection of typos in column names. + +This works, no matter how many joins are performed. For example: + +.. code-block:: python + + await Concert.select( + Concert.band_1._.name, + Concert.band_1._.manager._.name, + ) + +.. note:: You may wonder why this syntax is required. We're operating within + the limits of Python's typing support, which is still fairly young. In the + future we will hopefully be able to offer identical static typing support + for ``Band.manager.name`` and ``Band.manager._.name``. But even then, + the ``._.`` syntax will still be supported. + +``join_on`` +----------- + +Joins are usually performed using ``ForeignKey`` columns, though there may be +situations where you want to join using a column which isn't a ``ForeignKey``. + +You can do this using :meth:`join_on `. + +It's generally best to join on unique columns. diff --git a/docs/src/piccolo/query_types/objects.rst b/docs/src/piccolo/query_types/objects.rst index fd9cac66a..691ed55c2 100644 --- a/docs/src/piccolo/query_types/objects.rst +++ b/docs/src/piccolo/query_types/objects.rst @@ -11,54 +11,133 @@ can manipulate them, and save the changes back to the database. In Piccolo, an instance of a ``Table`` class represents a row. Let's do some examples. +------------------------------------------------------------------------------- + Fetching objects ---------------- -To get all objects: +To get all rows: + +.. 
code-block:: python
+
+    >>> await Band.objects()
+    [<Band: 1>, <Band: 2>, <Band: 3>]
+
+To limit the number of rows returned, use the :ref:`order_by` and :ref:`limit`
+clauses:

 .. code-block:: python

-    >>> Band.objects().run_sync()
-    [<Band: 1>, <Band: 2>]
+    >>> await Band.objects().order_by(Band.popularity, ascending=False).limit(2)
+    [<Band: 1>, <Band: 2>]

-To get certain rows:
+To filter the rows we use the :ref:`where` clause:

 .. code-block:: python

-    >>> Band.objects().where(Band.name == 'Pythonistas').run_sync()
+    >>> await Band.objects().where(Band.name == 'Pythonistas')
     [<Band: 1>]

-To get the first row:
+To get a single row (or ``None`` if it doesn't exist) use the :ref:`first`
+clause:
+
+.. code-block:: python
+
+    >>> await Band.objects().where(Band.name == 'Pythonistas').first()
+    <Band: 1>
+
+Alternatively, you can use this abbreviated syntax:

 .. code-block:: python

-    >>> Band.objects().first().run_sync()
+    >>> await Band.objects().get(Band.name == 'Pythonistas')
+    <Band: 1>

-You'll notice that the API is similar to :ref:`Select` - except it returns all
-columns.
+You'll notice that the API is similar to :ref:`Select` (except that with
+``select`` you can specify which columns are returned).
+
+-------------------------------------------------------------------------------

 Creating objects
 ----------------

+You can pass the column values using kwargs:
+
 .. code-block:: python

     >>> band = Band(name="C-Sharps", popularity=100)
-    >>> band.save().run_sync()
+    >>> await band.save()
+
+Alternatively, you can pass in a dictionary, which is friendlier to static
+analysis tools like Mypy (it can easily detect typos in the column names):
+
+.. code-block:: python
+
+    >>> band = Band({Band.name: "C-Sharps", Band.popularity: 100})
+    >>> await band.save()
+
+We also have this shortcut which combines the above into a single line:
+
+.. code-block:: python
+
+    >>> band = await Band.objects().create(name="C-Sharps", popularity=100)
+
+-------------------------------------------------------------------------------

 Updating objects
 ----------------

-Objects have a ``save`` method, which is convenient for updating values:
+``save``
+~~~~~~~~
+
+Objects have a :meth:`save ` method, which is
+convenient for updating values:

 .. code-block:: python

-    pythonistas = Band.objects().where(
+    band = await Band.objects().where(
         Band.name == 'Pythonistas'
-    ).first().run_sync()
+    ).first()
+
+    band.popularity = 100000
+
+    # This saves all values back to the database.
+    await band.save()
+
+    # Or specify particular columns to save:
+    await band.save([Band.popularity])
+
+``update_self``
+~~~~~~~~~~~~~~~
+
+The :meth:`save ` method is fine in the majority of
+cases, but there are some situations where the
+:meth:`update_self ` method is preferable.
+
+For example, if we want to increment the ``popularity`` value, we can do this:
+
+.. code-block:: python
+
+    await band.update_self({
+        Band.popularity: Band.popularity + 1
+    })
+
+Which does the following:
+
+* Increments the popularity in the database
+* Assigns the new value to the object
+
+This is safer than:

-    pythonistas.popularity = 100000
-    pythonistas.save().run_sync()
+.. code-block:: python
+
+    band.popularity += 1
+    await band.save()
+
+Because ``update_self`` increments the current ``popularity`` value in the
+database, not the one on the object, which might be out of date.
+
+-------------------------------------------------------------------------------

 Deleting objects
 ----------------

@@ -67,60 +146,311 @@ Similarly, we can delete objects, using the ``remove`` method.

..
code-block:: python

-    pythonistas = Band.objects().where(
+    band = await Band.objects().where(
         Band.name == 'Pythonistas'
-    ).first().run_sync()
+    ).first()

-    pythonistas.remove().run_sync()
+    await band.remove()

-get_related
------------
+-------------------------------------------------------------------------------

-If you have an object with a foreign key, and you want to fetch the related
-object, you can do so using ``get_related``.
+Fetching related objects
+------------------------
+
+``get_related``
+~~~~~~~~~~~~~~~
+
+If you have an object from a table with a :class:`ForeignKey `
+column, and you want to fetch the related row as an object, you can do so
+using ``get_related``.
+
+.. code-block:: python
+
+    band = await Band.objects().where(
+        Band.name == 'Pythonistas'
+    ).first()
+
+    manager = await band.get_related(Band.manager)
+    >>> manager
+    <Manager: 1>
+    >>> manager.name
+    'Guido'
+
+It works multiple levels deep - for example:
+
+.. code-block:: python
+
+    concert = await Concert.objects().first()
+    manager = await concert.get_related(Concert.band_1.manager)
+
+Prefetching related objects
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can also prefetch the rows from related tables, and store them as child
+objects. To do this, pass :class:`ForeignKey `
+columns into ``objects``, which refer to the related rows you want to load.

 .. code-block:: python

-    pythonistas = Band.objects().where(
+    band = await Band.objects(Band.manager).where(
         Band.name == 'Pythonistas'
-    ).first().run_sync()
+    ).first()

-    manager = pythonistas.get_related(Band.manager).run_sync()
-    >>> print(manager.name)
+    >>> band.manager
+    <Manager: 1>
+    >>> band.manager.name
     'Guido'

+If you have a table containing lots of ``ForeignKey`` columns, and want to
+prefetch them all, you can do so using ``all_related``.
+
+.. code-block:: python
+
+    ticket = await Ticket.objects(
+        Ticket.concert.all_related()
+    ).first()
+
+    # Any intermediate objects will also be loaded:
+    >>> ticket.concert
+    <Concert: 1>
+    >>> ticket.concert.band_1
+    <Band: 1>
+    >>> ticket.concert.band_2
+    <Band: 2>
+
+You can manipulate these nested objects, and save the values back to the
+database, just as you would expect:
+
+.. code-block:: python
+
+    ticket.concert.band_1.name = 'Pythonistas 2'
+    await ticket.concert.band_1.save()
+
+Instead of passing the :class:`ForeignKey `
+columns into the ``objects`` method, you can use the ``prefetch`` clause if you
+prefer.
+
+.. code-block:: python
+
+    # These are equivalent:
+    ticket = await Ticket.objects(
+        Ticket.concert.all_related()
+    ).first()
+
+    ticket = await Ticket.objects().prefetch(
+        Ticket.concert.all_related()
+    ).first()
+
+-------------------------------------------------------------------------------
+
+``get_or_create``
+-----------------
+
+With ``get_or_create`` you can get an existing record matching the criteria,
+or create a new one with the ``defaults`` arguments:
+
+.. code-block:: python
+
+    band = await Band.objects().get_or_create(
+        Band.name == 'Pythonistas', defaults={Band.popularity: 100}
+    )
+
+    # Or using string column names
+    band = await Band.objects().get_or_create(
+        Band.name == 'Pythonistas', defaults={'popularity': 100}
+    )
+
+You can find out if an existing row was found, or if a new row was created:
+
+.. code-block:: python
+
+    band = await Band.objects().get_or_create(
+        Band.name == 'Pythonistas'
+    )
+    band._was_created  # True if it was created, otherwise False if it was already in the db
+
+Complex where clauses are supported, but only within reason. For example:
+
+..
code-block:: python + + # This works OK: + band = await Band.objects().get_or_create( + (Band.name == 'Pythonistas') & (Band.popularity == 1000), + ) + + # This is problematic, as it's unclear what the name should be if we + # need to create the row: + band = await Band.objects().get_or_create( + (Band.name == 'Pythonistas') | (Band.name == 'Rustaceans'), + defaults={'popularity': 100} + ) + +------------------------------------------------------------------------------- + +``to_dict`` +----------- + +If you need to convert an object into a dictionary, you can do so using the +``to_dict`` method. + +.. code-block:: python + + band = await Band.objects().first() + + >>> band.to_dict() + {'id': 1, 'name': 'Pythonistas', 'manager': 1, 'popularity': 1000} + +If you only want a subset of the columns, or want to use aliases for some of +the columns: + +.. code-block:: python + + band = await Band.objects().first() + + >>> band.to_dict(Band.id, Band.name.as_alias('title')) + {'id': 1, 'title': 'Pythonistas'} + +------------------------------------------------------------------------------- + +``refresh`` +----------- + +If you have an object which has gotten stale, and want to refresh it, so it +has the latest data from the database, you can use the +:meth:`refresh ` method. + +.. code-block:: python + + # If we have an instance: + band = await Band.objects().first() + + # And it has gotten stale, we can refresh it: + await band.refresh() + + # Or just refresh certain columns: + await band.refresh([Band.name]) + +It works with ``prefetch`` too: + +.. code-block:: python + + # If we have an instance with a child object: + band = await Band.objects(Band.manager).first() + + # And it has gotten stale, we can refresh it: + await band.refresh() + + # The nested object will also be updated if it was stale: + >>> band.manager.name + "New value" + +``refresh`` is very useful in unit tests: + +.. code-block:: python + + # If we have an instance: + band = await Band.objects().where(Band.name == "Pythonistas").first() + + # Call an API endpoint which updates the object (e.g. with httpx): + await client.patch(f"/band/{band.id}/", json={"popularity": 5000}) + + # Make sure the instance was updated: + await band.refresh() + assert band.popularity == 5000 + +------------------------------------------------------------------------------- + +Comparing objects +----------------- + +If you have two objects, and you want to know whether they refer to the same +row in the database, you can simply use the equality operator: + +.. code-block:: python + + band_1 = await Band.objects().where(Band.name == "Pythonistas").first() + band_2 = await Band.objects().where(Band.name == "Pythonistas").first() + + >>> band_1 == band_2 + True + +It works by comparing the primary key value of each object. It's equivalent to +this: + +.. code-block:: python + + >>> band_1.id == band_2.id + True + +If the object has no primary key value yet (e.g. it uses a ``Serial`` column, +and it hasn't been saved in the database), then the result will always be +``False``: + +.. code-block:: python + + band_1 = Band() + band_2 = Band() + + >>> band_1 == band_2 + False + +If you want to compare every value on the objects, and not just the primary +key, you can use ``to_dict``. For example: + +.. 
code-block:: python + + >>> band_1.to_dict() == band_2.to_dict() + True + + >>> band_1.popularity = 10_000 + >>> band_1.to_dict() == band_2.to_dict() + False + +------------------------------------------------------------------------------- + Query clauses ------------- batch -~~~~~~~ +~~~~~ See :ref:`batch`. +callback +~~~~~~~~ + +See :ref:`callback`. + +first +~~~~~ + +See :ref:`first`. + limit ~~~~~ -See  :ref:`limit`. +See :ref:`limit`. -offset -~~~~~~ +lock_rows +~~~~~~~~~ -See  :ref:`offset`. +See :ref:`lock_rows`. -first -~~~~~ +offset +~~~~~~ -See  :ref:`first`. +See :ref:`offset`. order_by ~~~~~~~~ -See  :ref:`order_by`. +See :ref:`order_by`. output ~~~~~~ -See  :ref:`output`. +See :ref:`output`. where ~~~~~ diff --git a/docs/src/piccolo/query_types/raw.rst b/docs/src/piccolo/query_types/raw.rst index c06d265f7..4e1e0aaf4 100644 --- a/docs/src/piccolo/query_types/raw.rst +++ b/docs/src/piccolo/query_types/raw.rst @@ -7,16 +7,26 @@ Should you need to, you can execute raw SQL. .. code-block:: python - >>> Band.raw('select * from band').run_sync() - [{'name': 'Pythonistas', 'manager': 1, 'popularity': 1000, 'id': 1}, - {'name': 'Rustaceans', 'manager': 2, 'popularity': 500, 'id': 2}] + >>> await Band.raw('SELECT name FROM band') + [{'name': 'Pythonistas'}] It's recommended that you parameterise any values. Use curly braces ``{}`` as placeholders: .. code-block:: python - >>> Band.raw('select * from band where name = {}', 'Pythonistas').run_sync() + >>> await Band.raw('SELECT * FROM band WHERE name = {}', 'Pythonistas') [{'name': 'Pythonistas', 'manager': 1, 'popularity': 1000, 'id': 1}] .. warning:: Be careful to avoid SQL injection attacks. Don't add any user submitted data into your SQL strings, unless it's parameterised. + + +------------------------------------------------------------------------------- + +Query clauses +------------- + +batch +~~~~~ + +See :ref:`batch`. diff --git a/docs/src/piccolo/query_types/select.rst b/docs/src/piccolo/query_types/select.rst index 16c24bbe5..2025f35f4 100644 --- a/docs/src/piccolo/query_types/select.rst +++ b/docs/src/piccolo/query_types/select.rst @@ -3,13 +3,16 @@ Select ====== -.. hint:: Follow along by installing Piccolo and running `piccolo playground run` - see :ref:`Playground` +.. hint:: Follow along by installing Piccolo and running ``piccolo playground run`` - see :ref:`Playground`. -To get all rows: +Columns +------- + +To get all columns: .. code-block:: python - >>> Band.select().run_sync() + >>> await Band.select() [{'id': 1, 'name': 'Pythonistas', 'manager': 1, 'popularity': 1000}, {'id': 2, 'name': 'Rustaceans', 'manager': 2, 'popularity': 500}] @@ -17,19 +20,22 @@ To get certain columns: .. code-block:: python - >>> Band.select(Band.name).run_sync() + >>> await Band.select(Band.name) [{'name': 'Rustaceans'}, {'name': 'Pythonistas'}] -Or making an alias to make it shorter: +Or use an alias to make it shorter: .. code-block:: python >>> b = Band - >>> b.select(b.name).run_sync() - [{'id': 1, 'name': 'Pythonistas', 'manager': 1, 'popularity': 1000}, - {'id': 2, 'name': 'Rustaceans', 'manager': 2, 'popularity': 500}] + >>> await b.select(b.name) + [{'name': 'Rustaceans'}, {'name': 'Pythonistas'}] + +.. hint:: + All of these examples also work synchronously using ``run_sync`` - + see :ref:`SyncAndAsync`. -.. hint:: All of these examples also work with async by using .run() inside coroutines - see :ref:`SyncAndAsync`. 
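+
+For example, here is a minimal sketch of the synchronous equivalent of the
+query above (using the same example schema, outside of an async context):
+
+.. code-block:: python
+
+    >>> Band.select(Band.name).run_sync()
+    [{'name': 'Rustaceans'}, {'name': 'Pythonistas'}]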
+-------------------------------------------------------------------------------

 as_alias
 --------

@@ -38,60 +44,178 @@ By using ``as_alias``, the name of the column can be overridden in the response.

 .. code-block:: python

-    >>> Band.select(Band.name.as_alias('title')).run_sync()
+    >>> await Band.select(Band.name.as_alias('title'))
     [{'title': 'Rustaceans'}, {'title': 'Pythonistas'}]

 This is equivalent to ``SELECT name AS title FROM band`` in SQL.

+-------------------------------------------------------------------------------
+
 Joins
 -----

-One of the most powerful things about select is it's support for joins.
+One of the most powerful things about ``select`` is its support for joins.

 .. code-block:: python

-    >>> b = Band
-    >>> b.select(b.name, b.manager.name).run_sync()
-    [{'name': 'Pythonistas', 'manager.name': 'Guido'}, {'name': 'Rustaceans', 'manager.name': 'Graydon'}]
+    >>> await Band.select(Band.name, Band.manager.name)
+    [
+        {'name': 'Pythonistas', 'manager.name': 'Guido'},
+        {'name': 'Rustaceans', 'manager.name': 'Graydon'}
+    ]

 The joins can go several layers deep.

 .. code-block:: python

-    c = Concert
-    c.select(
-        c.id,
-        c.band_1.manager.name
-    ).run_sync()
+    >>> await Concert.select(Concert.id, Concert.band_1.manager.name)
+    [{'id': 1, 'band_1.manager.name': 'Guido'}]
+
+all_columns
+~~~~~~~~~~~
+
+If you want all of the columns from a related table you can use
+``all_columns``, which is a useful shortcut which saves you from typing them
+all out:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.manager.all_columns())
+    [
+        {'name': 'Pythonistas', 'manager.id': 1, 'manager.name': 'Guido'},
+        {'name': 'Rustaceans', 'manager.id': 2, 'manager.name': 'Graydon'}
+    ]
+
+In Piccolo < 0.41.0 you had to explicitly unpack ``all_columns``. This is
+equivalent to the code above:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, *Band.manager.all_columns())
+
+You can exclude some columns if you like:
+
+.. code-block:: python
+
+    >>> await Band.select(
+    ...     Band.name,
+    ...     Band.manager.all_columns(exclude=[Band.manager.id])
+    ... )
+    [
+        {'name': 'Pythonistas', 'manager.name': 'Guido'},
+        {'name': 'Rustaceans', 'manager.name': 'Graydon'}
+    ]
+
+Strings are supported too if you prefer:
+
+.. code-block:: python
+
+    >>> await Band.select(
+    ...     Band.name,
+    ...     Band.manager.all_columns(exclude=['id'])
+    ... )
+    [
+        {'name': 'Pythonistas', 'manager.name': 'Guido'},
+        {'name': 'Rustaceans', 'manager.name': 'Graydon'}
+    ]
+
+You can also use ``all_columns`` on the root table, which saves you time if
+you have lots of columns. It works identically to related tables:
+
+.. code-block:: python
+
+    >>> await Band.select(
+    ...     Band.all_columns(exclude=[Band.id]),
+    ...     Band.manager.all_columns(exclude=[Band.manager.id])
+    ... )
+    [
+        {'name': 'Pythonistas', 'popularity': 1000, 'manager.name': 'Guido'},
+        {'name': 'Rustaceans', 'popularity': 500, 'manager.name': 'Graydon'}
+    ]
+
+Nested
+~~~~~~
+
+You can also get the response as nested dictionaries, which can be very useful:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.manager.all_columns()).output(nested=True)
+    [
+        {'name': 'Pythonistas', 'manager': {'id': 1, 'name': 'Guido'}},
+        {'name': 'Rustaceans', 'manager': {'id': 2, 'name': 'Graydon'}}
+    ]
+
+-------------------------------------------------------------------------------

 String syntax
 -------------

-Alternatively, you can specify the column names using a string. The
+You can specify the column names using a string if you prefer. The
 disadvantage is you won't have tab completion, but sometimes it's more
 convenient.

 .. code-block:: python

-    Band.select('name').run_sync()
+    await Band.select('name')

     # For joins:
-    Band.select('manager.name').run_sync()
+    await Band.select('manager.name')
+
+-------------------------------------------------------------------------------
+
+String functions
+----------------
+
+Piccolo has lots of string functions built-in. See
+``piccolo/query/functions/string.py``. Here's an example using ``Upper``, to
+convert values to uppercase:
+
+.. code-block:: python
+
+    from piccolo.query.functions.string import Upper
+
+    >>> await Band.select(Upper(Band.name, alias='name'))
+    [{'name': 'PYTHONISTAS'}, ...]
+
+You can also use these within where clauses:
+
+.. code-block:: python
+
+    from piccolo.query.functions.string import Upper
+
+    >>> await Band.select(Band.name).where(Upper(Band.manager.name) == 'GUIDO')
+    [{'name': 'Pythonistas'}]
+
+-------------------------------------------------------------------------------
+
+.. _AggregateFunctions:

 Aggregate functions
 -------------------

+.. note:: These can all be used in conjunction with the :ref:`group_by` clause.
+
 Count
 ~~~~~

-Returns the number of rows which match the query:
+.. hint:: You can use the :ref:`count` query as a quick way of getting
+    the number of rows in a table.
+
+Returns the number of matching rows.

 .. code-block:: python

-    >>> Band.count().where(Band.name == 'Pythonistas').run_sync()
-    1
+    from piccolo.query.functions.aggregate import Count
+
+    >>> await Band.select(Count()).where(Band.popularity > 100)
+    [{'count': 3}]
+
+To find out more about the options available, see :class:`Count `.

@@ -100,8 +224,8 @@
 Avg
 ~~~

 Returns the average for a given column:

 .. code-block:: python

-    >>> from piccolo.query import Avg
-    >>> response = Band.select(Avg(Band.popularity)).first().run_sync()
+    >>> from piccolo.query.functions.aggregate import Avg
+    >>> response = await Band.select(Avg(Band.popularity)).first()
     >>> response["avg"]
     750.0

@@ -112,8 +236,8 @@
 Sum
 ~~~

 Returns the sum for a given column:

 .. code-block:: python

-    >>> from piccolo.query import Sum
-    >>> response = Band.select(Sum(Band.popularity)).first().run_sync()
+    >>> from piccolo.query.functions.aggregate import Sum
+    >>> response = await Band.select(Sum(Band.popularity)).first()
     >>> response["sum"]
     1500

@@ -124,8 +248,8 @@
 Max
 ~~~

 Returns the maximum for a given column:

 .. code-block:: python

-    >>> from piccolo.query import Max
-    >>> response = Band.select(Max(Band.popularity)).first().run_sync()
+    >>> from piccolo.query.functions.aggregate import Max
+    >>> response = await Band.select(Max(Band.popularity)).first()
     >>> response["max"]
     1000

@@ -136,20 +260,23 @@
 Min
 ~~~

 Returns the minimum for a given column:

 .. code-block:: python

-    >>> from piccolo.query import Min
-    >>> response = Band.select(Min(Band.popularity)).first().run_sync()
+    >>> from piccolo.query.functions.aggregate import Min
+    >>> response = await Band.select(Min(Band.popularity)).first()
     >>> response["min"]
     500

 Additional features
 ~~~~~~~~~~~~~~~~~~~

-You also can chain multiple different aggregate functions in one query:
+You can also have multiple different aggregate functions in one query:

 .. code-block:: python

-    >>> from piccolo.query import Avg, Sum
-    >>> response = Band.select(Avg(Band.popularity), Sum(Band.popularity)).first().run_sync()
+    >>> from piccolo.query.functions.aggregate import Avg, Sum
+    >>> response = await Band.select(
+    ...     Avg(Band.popularity),
+    ...     Sum(Band.popularity)
+    ... ).first()
     >>> response
     {"avg": 750.0, "sum": 1500}

@@ -157,16 +284,36 @@ And can use aliases for aggregate functions like this:

 .. code-block:: python

-    >>> from piccolo.query import Avg
-    >>> response = Band.select(Avg(Band.popularity, alias="popularity_avg")).first().run_sync()
-    >>> response["popularity_avg"]
-    750.0
-    # Alternatively, you can use the `as_alias` method.
-    >>> response = Band.select(Avg(Band.popularity).as_alias("popularity_avg")).first().run_sync()
+    >>> response = await Band.select(
+    ...     Avg(Band.popularity).as_alias("popularity_avg")
+    ... ).first()
     >>> response["popularity_avg"]
     750.0

+-------------------------------------------------------------------------------
+
+SelectRaw
+---------
+
+In certain situations you may want to have raw SQL in your select query.
+
+For example, if you want to access a Postgres function which isn't supported
+by Piccolo:
+
+.. code-block:: python
+
+    from piccolo.query import SelectRaw
+
+    >>> await Band.select(
+    ...     Band.name,
+    ...     SelectRaw("log(popularity) AS log_popularity")
+    ... )
+    [{'name': 'Pythonistas', 'log_popularity': 3.0}]
+
+.. warning:: Only use SQL that you trust.
+
+-------------------------------------------------------------------------------

 Query clauses
 -------------

@@ -176,6 +323,11 @@ batch
 ~~~~~

 See :ref:`batch`.

+callback
+~~~~~~~~
+
+See :ref:`callback`.
+
 columns
 ~~~~~~~

@@ -183,60 +335,65 @@ By default all columns are returned from the queried table.

 .. code-block:: python

-    b = Band
     # Equivalent to SELECT * from band
-    b.select().run_sync()
+    await Band.select()

 To restrict the returned columns, either pass in the columns into the
 ``select`` method, or use the ``columns`` method.

 .. code-block:: python

-    b = Band
     # Equivalent to SELECT name from band
-    b.select().columns(b.name).run_sync()
+    await Band.select(Band.name)
+
+    # Or alternatively:
+    await Band.select().columns(Band.name)

 The ``columns`` method is additive, meaning you can chain it to add additional
 columns.

 .. code-block:: python

-    b = Band
-    b.select().columns(b.name).columns(b.manager).run_sync()
+    await Band.select().columns(Band.name).columns(Band.manager)

     # Or just define it in one go:
-    b.select().columns(b.name, b.manager).run_sync()
+    await Band.select().columns(Band.name, Band.manager)

+distinct
+~~~~~~~~
+
+See :ref:`distinct`.

 first
 ~~~~~

-See  :ref:`first`.
+See :ref:`first`.

 group_by
 ~~~~~~~~

-See  :ref:`group_by`.
+See :ref:`group_by`.

 limit
 ~~~~~

-See  :ref:`limit`.
+See :ref:`limit`.

-offset
-~~~~~~
+lock_rows
+~~~~~~~~~

-See  :ref:`offset`.
+See :ref:`lock_rows`.

-distinct
-~~~~~~~~
+offset
+~~~~~~

-See  :ref:`distinct`.
+See :ref:`offset`.

 order_by
 ~~~~~~~~

-See  :ref:`order_by`.
+See :ref:`order_by`.

 output
 ~~~~~~

@@ -246,4 +403,4 @@ See :ref:`output`.

 where
 ~~~~~

-See  :ref:`where`.
+See :ref:`where`.
diff --git a/docs/src/piccolo/query_types/transactions.rst b/docs/src/piccolo/query_types/transactions.rst
index 919cebe30..803d1cfa3 100644
--- a/docs/src/piccolo/query_types/transactions.rst
+++ b/docs/src/piccolo/query_types/transactions.rst
@@ -8,6 +8,22 @@ Transactions allow multiple queries to be committed only once successful.

 This is useful for things like migrations, where you can't have it fail in an
 in-between state.

+-------------------------------------------------------------------------------
+
+Accessing the ``Engine``
+------------------------
+
+In the examples below we need to access the database ``Engine``.
+
+Each ``Table`` contains a reference to its ``Engine``, which is the easiest
+way to access it. For example, with our ``Band`` table:
+
+.. code-block:: python
+
+    DB = Band._meta.db
+
+-------------------------------------------------------------------------------
+
 Atomic
 ------

@@ -16,7 +32,7 @@ transaction before running it.

 .. code-block:: python

-    transaction = Band._meta.db.atomic()
+    transaction = DB.atomic()
     transaction.add(Manager.create_table())
     transaction.add(Concert.create_table())
     await transaction.run()

     # You're also able to run this synchronously:
     transaction.run_sync()

+-------------------------------------------------------------------------------
+
 Transaction
 -----------

@@ -32,10 +50,137 @@ async.

 .. code-block:: python

-    async with Band._meta.db.transaction():
-        await Manager.create_table().run()
-        await Concert.create_table().run()
+    async with DB.transaction():
+        await Manager.create_table()
+        await Concert.create_table()
+
+Commit
+~~~~~~
+
+The transaction is automatically committed when you exit the context manager.
+
+.. code-block:: python
+
+    async with DB.transaction():
+        await query_1
+        await query_2
+        # Automatically committed if the code reaches here.
+
+You can manually commit it if you prefer:
+
+.. code-block:: python
+
+    async with DB.transaction() as transaction:
+        await query_1
+        await query_2
+        await transaction.commit()
+        print('transaction committed!')
+
+Rollback
+~~~~~~~~

 If an exception is raised within the body of the context manager, then the
 transaction is automatically rolled back. The exception is still propagated
 though.
+
+Rather than raising an exception, if you want to roll back a transaction
+manually you can do so as follows:
+
+.. code-block:: python
+
+    async with DB.transaction() as transaction:
+        await Manager.create_table()
+        await Band.create_table()
+        await transaction.rollback()
+
+-------------------------------------------------------------------------------
+
+Nested transactions
+-------------------
+
+Nested transactions aren't supported in Postgres, but we can achieve something
+similar using `savepoints `_.
+
+Nested context managers
+~~~~~~~~~~~~~~~~~~~~~~~
+
+If you have nested context managers, for example:
+
+.. code-block:: python
+
+    async with DB.transaction():
+        async with DB.transaction():
+            ...
+
+By default, the inner context manager does nothing, as we're already inside a
+transaction.
+
+You can change this behaviour using ``allow_nested=False``, in which case a
+``TransactionError`` is raised if you try creating a transaction when one
+already exists.
+
+.. code-block:: python
+
+    async with DB.transaction():
+        async with DB.transaction(allow_nested=False):
+            # TransactionError('A transaction is already active.')
+            ...
+
+``transaction_exists``
+~~~~~~~~~~~~~~~~~~~~~~
+
+You can check whether your code is currently inside a transaction using the
+following:
+
+.. code-block:: python
+
+    >>> DB.transaction_exists()
+    True
+
+-------------------------------------------------------------------------------
+
+Savepoints
+----------
+
+Postgres supports savepoints, which are a way of partially rolling back a
+transaction.
+
+.. code-block:: python
+
+    async with DB.transaction() as transaction:
+        await Band.insert(Band(name='Pythonistas'))
+
+        savepoint_1 = await transaction.savepoint()
+
+        await Band.insert(Band(name='Terrible band'))
+
+        # Oops, I made a mistake!
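+        # Rolling back to the savepoint undoes the 'Terrible band' insert,
+        # but keeps everything inserted before the savepoint: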
+ await savepoint_1.rollback_to() + +In the above example, the first query will be committed, but not the second. + +Named savepoints +~~~~~~~~~~~~~~~~ + +By default, we assign a name to the savepoint for you. But you can explicitly +give it a name: + +.. code-block:: python + + await transaction.savepoint('my_savepoint') + +This means you can rollback to this savepoint at any point just using the name: + +.. code-block:: python + + await transaction.rollback_to('my_savepoint') + +------------------------------------------------------------------------------- + +Transaction types +----------------- + +SQLite +~~~~~~ + +For SQLite you may want to specify the :ref:`transaction type `, +as it can have an effect on how well the database handles concurrent requests. diff --git a/docs/src/piccolo/query_types/update.rst b/docs/src/piccolo/query_types/update.rst index 0d24e5db4..acd0e1bb8 100644 --- a/docs/src/piccolo/query_types/update.rst +++ b/docs/src/piccolo/query_types/update.rst @@ -1,3 +1,4 @@ + .. _Update: Update @@ -7,20 +8,43 @@ This is used to update any rows in the table which match the criteria. .. code-block:: python - >>> Band.update({ - >>> Band.name: 'Pythonistas 2' - >>> }).where( - >>> Band.name == 'Pythonistas' - >>> ).run_sync() + >>> await Band.update({ + ... Band.name: 'Pythonistas 2' + ... }).where( + ... Band.name == 'Pythonistas' + ... ) [] -As well as replacing values with new ones, you can also modify existing values, for -instance by adding to an integer. +------------------------------------------------------------------------------- + +force +----- + +Piccolo won't let you run an update query without a :ref:`where clause `, +unless you explicitly tell it to do so. This is to prevent accidentally +overwriting the data in a table. + +.. code-block:: python + + >>> await Band.update() + Raises: UpdateError + # Works fine: + >>> await Band.update({Band.popularity: 0}, force=True) + + # Or just add a where clause: + >>> await Band.update({Band.popularity: 0}).where(Band.popularity < 50) + +------------------------------------------------------------------------------- Modifying values ---------------- +As well as replacing values with new ones, you can also modify existing values, +for instance by adding an integer. + +You can currently only combine two values together at a time. + Integer columns ~~~~~~~~~~~~~~~ @@ -29,29 +53,44 @@ You can add / subtract / multiply / divide values: .. code-block:: python # Add 100 to the popularity of each band: - Band.update({ - Band.popularity: Band.popularity + 100 - }).run_sync() + await Band.update( + { + Band.popularity: Band.popularity + 100 + }, + force=True + ) # Decrease the popularity of each band by 100. - Band.update({ - Band.popularity: Band.popularity - 100 - }).run_sync() + await Band.update( + { + Band.popularity: Band.popularity - 100 + }, + force=True + ) # Multiply the popularity of each band by 10. - Band.update({ - Band.popularity: Band.popularity * 10 - }).run_sync() + await Band.update( + { + Band.popularity: Band.popularity * 10 + }, + force=True + ) # Divide the popularity of each band by 10. 
-    Band.update({
-        Band.popularity: Band.popularity / 10
-    }).run_sync()
+    await Band.update(
+        {
+            Band.popularity: Band.popularity / 10
+        },
+        force=True
+    )

     # You can also use the operators in reverse:
-    Band.update({
-        Band.popularity: 2000 - Band.popularity
-    }).run_sync()
+    await Band.update(
+        {
+            Band.popularity: 2000 - Band.popularity
+        },
+        force=True
+    )

 Varchar / Text columns
 ~~~~~~~~~~~~~~~~~~~~~~

@@ -61,27 +100,130 @@ You can concatenate values:

 .. code-block:: python

     # Append "!!!" to each band name.
-    Band.update({
-        Band.name: Band.name + "!!!"
-    }).run_sync()
+    await Band.update(
+        {
+            Band.name: Band.name + "!!!"
+        },
+        force=True
+    )

     # Concatenate the values in each column:
-    Band.update({
-        Band.name: Band.name + Band.name
-    }).run_sync()
+    await Band.update(
+        {
+            Band.name: Band.name + Band.name
+        },
+        force=True
+    )

     # Prepend "!!!" to each band name.
-    Band.update({
-        Band.popularity: "!!!" + Band.popularity
-    }).run_sync()
+    await Band.update(
+        {
+            Band.name: "!!!" + Band.name
+        },
+        force=True
+    )

+Date / Timestamp / Timestamptz / Interval columns
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can add or subtract a :class:`timedelta ` to any of
+these columns.
+
+For example, if we have a ``Concert`` table, and we want each concert to start
+one day later, we can simply do this:
+
+.. code-block:: python
+
+    await Concert.update(
+        {
+            Concert.starts: Concert.starts + datetime.timedelta(days=1)
+        },
+        force=True
+    )
+
+Likewise, we can decrease the values by 1 day:
+
+.. code-block:: python
+
+    await Concert.update(
+        {
+            Concert.starts: Concert.starts - datetime.timedelta(days=1)
+        },
+        force=True
+    )
+
+Array columns
+~~~~~~~~~~~~~
+
+You can append values to an array (Postgres only). See :meth:`cat `.

-You can currently only combine two values together at a time.
+What about null values?
+~~~~~~~~~~~~~~~~~~~~~~~
+
+If we have a table with a nullable column:
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar(null=True)
+
+Any rows with a value of null aren't modified by an update:
+
+.. code-block:: python
+
+    >>> await Band.insert(Band(name="Pythonistas"), Band(name=None))
+    >>> await Band.update(
+    ...     {
+    ...         Band.name: Band.name + '!!!'
+    ...     },
+    ...     force=True
+    ... )
+    >>> await Band.select()
+    # Note how the second row's name value is still `None`:
+    [{'id': 1, 'name': 'Pythonistas!!!'}, {'id': 2, 'name': None}]
+
+It's more efficient to exclude any rows with a value of null using a
+:ref:`where clause `:
+
+.. code-block:: python
+
+    await Band.update(
+        {
+            Band.name: Band.name + '!!!'
+        },
+        force=True
+    ).where(
+        Band.name.is_not_null()
+    )
+
+-------------------------------------------------------------------------------
+
+Kwarg values
+------------
+
+Rather than passing in a dictionary of values, you can use kwargs instead if
+you prefer:
+
+.. code-block:: python
+
+    await Band.update(
+        name='Pythonistas 2'
+    ).where(
+        Band.name == 'Pythonistas'
+    )
+
+-------------------------------------------------------------------------------

 Query clauses
 -------------

+returning
+~~~~~~~~~
+
+See :ref:`Returning`.
+
 where
 ~~~~~

diff --git a/docs/src/piccolo/schema/advanced.rst b/docs/src/piccolo/schema/advanced.rst
index 12c900222..0e4f3db0f 100644
--- a/docs/src/piccolo/schema/advanced.rst
+++ b/docs/src/piccolo/schema/advanced.rst
@@ -3,11 +3,65 @@ Advanced
 ========

+.. _Schemas:
+
+Schemas
+-------
+
+Postgres and CockroachDB have a concept called **schemas**.
+
+It's a way of grouping the tables in a database.
+To learn more:
+
+* `Postgres docs `_
+* `CockroachDB docs `_
+
+To specify a table's schema, do the following:
+
+.. code-block:: python
+
+    class Band(Table, schema="music"):
+        ...
+
+    # The table will be created in the `music` schema.
+    # The music schema will also be created if it doesn't already exist.
+    >>> await Band.create_table()
+
+If the ``schema`` argument isn't specified, then the table is created in the
+``public`` schema.
+
+Migration support
+~~~~~~~~~~~~~~~~~
+
+Schemas are fully supported in :ref:`database migrations `.
+For example, if we change the ``schema`` argument:
+
+.. code-block:: python
+
+    class Band(Table, schema="music_2"):
+        ...
+
+If you then create an automatic migration and run it, the table will be moved
+to the new schema:
+
+.. code-block:: bash
+
+    >>> piccolo migrations new my_app --auto
+    >>> piccolo migrations forwards my_app
+
+``SchemaManager``
+~~~~~~~~~~~~~~~~~
+
+The :class:`SchemaManager ` class is used
+internally by Piccolo to interact with schemas. You may find it useful if you
+want to write a script to interact with schemas (create / delete / list, etc.).
+
+-------------------------------------------------------------------------------
+
 Readable
 --------

 Sometimes Piccolo needs a succinct representation of a row - for example, when
-displaying a link in the Piccolo Admin GUI (see :ref:`Ecosystem`). Rather than
+displaying a link in the :ref:`Piccolo Admin `. Rather than
 just displaying the row ID, we can specify something more user friendly using
 ``Readable``.

@@ -32,7 +86,7 @@ tooling - you can also use it in your own queries.

 .. code-block:: python

-    Band.select(Band.get_readable()).run_sync()
+    await Band.select(Band.get_readable())

 Here is an example of a more complex ``Readable``.

@@ -56,7 +110,7 @@ Table Tags
 ----------

 ``Table`` subclasses can be given tags. The tags can be used for filtering,
-for example with ``table_finder`` (see :ref:`TableFinder`).
+for example with :ref:`table_finder `.

@@ -89,7 +143,7 @@ use mixins to reduce the amount of repetition.

 Choices
 -------

-You can specify choices for a column, using Python's ``Enum`` support.
+You can specify choices for a column, using Python's :class:`Enum ` support.

 .. code-block:: python

@@ -111,9 +165,9 @@ We can then use the ``Enum`` in our queries.

 .. code-block:: python

-    >>> Shirt(size=Shirt.Size.large).save().run_sync()
+    >>> await Shirt(size=Shirt.Size.large).save()

-    >>> Shirt.select().run_sync()
+    >>> await Shirt.select()
     [{'id': 1, 'size': 'l'}]

 Note how the value stored in the database is the ``Enum`` value (in this case ``'l'``).

@@ -123,12 +177,12 @@ where a query requires a value.

 .. code-block:: python

-    >>> Shirt.insert(
-    >>>     Shirt(size=Shirt.Size.small),
-    >>>     Shirt(size=Shirt.Size.medium)
-    >>> ).run_sync()
+    >>> await Shirt.insert(
+    ...     Shirt(size=Shirt.Size.small),
+    ...     Shirt(size=Shirt.Size.medium)
+    ... )

-    >>> Shirt.select().where(Shirt.size == Shirt.Size.small).run_sync()
+    >>> await Shirt.select().where(Shirt.size == Shirt.Size.small)
     [{'id': 1, 'size': 's'}]

 Advantages
@@ -136,6 +190,189 @@ Advantages

 By using choices, you get the following benefits:

-  * Signalling to other programmers what values are acceptable for the column.
-  * Improved storage efficiency (we can store ``'l'`` instead of ``'large'``).
-  * Piccolo Admin support
+* Signalling to other programmers what values are acceptable for the column.
+* Improved storage efficiency (we can store ``'l'`` instead of ``'large'``).
+* Piccolo Admin support
+
+``Array`` columns
+~~~~~~~~~~~~~~~~~
+
+You can also use choices with :class:`Array `
+columns.
+
+.. code-block:: python
+
+    class Ticket(Table):
+        class Extras(str, enum.Enum):
+            drink = "drink"
+            snack = "snack"
+            program = "program"
+
+        extras = Array(Varchar(), choices=Extras)
+
+Note how you pass ``choices`` to ``Array``, and not the ``base_column``:
+
+.. code-block:: python
+
+    # CORRECT:
+    Array(Varchar(), choices=Extras)
+
+    # INCORRECT:
+    Array(Varchar(choices=Extras))
+
+We can then use the ``Enum`` in our queries:
+
+.. code-block:: python
+
+    >>> await Ticket.insert(
+    ...     Ticket(extras=[Extras.drink, Extras.snack]),
+    ...     Ticket(extras=[Extras.program]),
+    ... )
+
+
+-------------------------------------------------------------------------------
+
+Reflection
+----------
+
+This is a very advanced feature, which is only required for specialist use
+cases. Currently, just Postgres is supported.
+
+Instead of writing your ``Table`` definitions in a ``tables.py`` file, Piccolo
+can dynamically create them at run time, by inspecting the database. These
+``Table`` classes are then stored in memory, using a singleton object called
+``TableStorage``.
+
+Some example use cases:
+
+* You have a very dynamic database, where new tables are being created
+  constantly, so updating a ``tables.py`` is impractical.
+* You use Piccolo on the command line to explore databases.
+
+Full reflection
+~~~~~~~~~~~~~~~
+
+Here's an example, where we reflect the entire schema:
+
+.. code-block:: python
+
+    from piccolo.table_reflection import TableStorage
+
+    storage = TableStorage()
+    await storage.reflect(schema_name="music")
+
+``Table`` objects are accessible from ``TableStorage.tables``:
+
+.. code-block:: python
+
+    >>> storage.tables
+    {"music.Band": , ... }
+
+    >>> Band = storage.tables["music.Band"]
+
+Then you can use them like your normal ``Table`` classes:
+
+.. code-block:: python
+
+    >>> await Band.select()
+    [{'id': 1, 'name': 'Pythonistas', 'manager': 1}, ...]
+
+
+Partial reflection
+~~~~~~~~~~~~~~~~~~
+
+Full schema reflection can be a heavy process, depending on the size of your
+schema. You can use the ``include``, ``exclude`` and ``keep_existing``
+parameters of the ``reflect`` method to reduce the overhead dramatically.
+
+Only reflect the needed table(s):
+
+.. code-block:: python
+
+    from piccolo.table_reflection import TableStorage
+
+    storage = TableStorage()
+    await storage.reflect(schema_name="music", include=['band', ...])
+
+Exclude table(s):
+
+.. code-block:: python
+
+    await storage.reflect(schema_name="music", exclude=['band', ...])
+
+If you set ``keep_existing=True``, only new tables on the database will be
+reflected and the existing tables in ``TableStorage`` will be left intact.
+
+.. code-block:: python
+
+    await storage.reflect(schema_name="music", keep_existing=True)
+
+get_table
+~~~~~~~~~
+
+``TableStorage`` has a helper method named ``get_table``. If the table is
+already present in ``TableStorage``, it will be returned; if not, it will be
+reflected and then returned.
+
+.. code-block:: python
+
+    Band = await storage.get_table(tablename='band')
+
+.. hint:: Reflection will automatically create ``Table`` classes for referenced
+   tables too. For example, if ``Table1`` references ``Table2``, then
+   ``Table2`` will automatically be added to ``TableStorage``.
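+
+For example (a sketch, assuming hypothetical ``band`` and ``manager`` tables,
+where ``band.manager`` is a foreign key to ``manager``):
+
+.. code-block:: python
+
+    from piccolo.table_reflection import TableStorage
+
+    storage = TableStorage()
+
+    # Reflecting `band` also reflects the `manager` table it references:
+    Band = await storage.get_table(tablename="band")
+
+    # So foreign key traversal works as you'd expect:
+    >>> await Band.select(Band.name, Band.manager.name)
+    [{'name': 'Pythonistas', 'manager.name': 'Guido'}, ...]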
+
+-------------------------------------------------------------------------------
+
+How to create custom column types
+---------------------------------
+
+Sometimes, the column types shipped with Piccolo don't meet your requirements, and you
+will need to define your own column types.
+
+Generally there are two ways to define your own column types:
+
+* Create a subclass of an existing column type; or
+* Directly subclass the :ref:`Column ` class.
+
+Try to use the first method whenever possible because it is more straightforward and
+can often save you some work. Otherwise, subclass :ref:`Column `.
+
+**Example**
+
+In this example, we create a column type called ``MyColumn``, which is fundamentally
+an ``Integer`` type but has a custom attribute ``custom_attr``:
+
+.. code-block:: python
+
+    from piccolo.columns import Integer
+
+    class MyColumn(Integer):
+        def __init__(self, *args, custom_attr: str = '', **kwargs):
+            self.custom_attr = custom_attr
+            super().__init__(*args, **kwargs)
+
+        @property
+        def column_type(self):
+            return 'INTEGER'
+
+.. hint:: It is **important** to specify the ``column_type`` property, which
+   tells the database engine the **actual** storage type of the custom
+   column.
+
+Now we can use ``MyColumn`` in our table:
+
+.. code-block:: python
+
+    from piccolo.table import Table
+
+    class MyTable(Table):
+        my_col = MyColumn(custom_attr='foo')
+        ...
+
+And later we can retrieve the value of the attribute:
+
+.. code-block:: python
+
+    >>> MyTable.my_col.custom_attr
+    'foo'
diff --git a/docs/src/piccolo/schema/column_types.rst b/docs/src/piccolo/schema/column_types.rst
index ce3f72424..bdf2258bf 100644
--- a/docs/src/piccolo/schema/column_types.rst
+++ b/docs/src/piccolo/schema/column_types.rst
@@ -1,7 +1,5 @@
 .. _ColumnTypes:
 
-.. currentmodule:: piccolo.columns.column_types
-
 ############
 Column Types
 ############
@@ -13,7 +11,10 @@ Column Types
 Column
 ******
 
+.. currentmodule:: piccolo.columns.base
+
 .. autoclass:: Column
+    :noindex:
 
 -------------------------------------------------------------------------------
 
@@ -21,6 +22,8 @@ Column
 Bytea
 *****
 
+.. currentmodule:: piccolo.columns.column_types
+
 .. autoclass:: Bytea
 
 .. hint:: There is also a ``Blob`` column type, which is an alias for
@@ -54,6 +57,18 @@ BigInt
 
 .. autoclass:: BigInt
 
+=========
+BigSerial
+=========
+
+.. autoclass:: BigSerial
+
+================
+Double Precision
+================
+
+.. autoclass:: DoublePrecision
+
 =======
 Integer
 =======
@@ -78,6 +93,12 @@ Real
 
 .. hint:: There is also a ``Float`` column type, which is an alias for
    ``Real``.
 
+======
+Serial
+======
+
+.. autoclass:: Serial
+
 ========
 SmallInt
 ========
@@ -116,7 +137,13 @@ Varchar
 
 .. autoclass:: Varchar
 
--------------------------------------------------------------------------------
+=====
+Email
+=====
+
+.. autoclass:: Email
+
+--------------------------------------------------------------------------
 
 ****
 Time
 ****
@@ -162,48 +189,219 @@
 Storing JSON can be useful in certain situations, for example - raw API
 responses, data from a Javascript app, and for storing data with an unknown or
 changing schema.
 
-====
-JSON
-====
+====================
+``JSON`` / ``JSONB``
+====================
 
 .. autoclass:: JSON
 
-=====
-JSONB
-=====
-
 .. autoclass:: JSONB
 
-arrow
-=====
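+The examples in the following sections assume a table along these lines (a
+sketch; the exact columns are an assumption):
+
+.. code-block:: python
+
+    from piccolo.columns import JSONB, Varchar
+    from piccolo.table import Table
+
+    class RecordingStudio(Table):
+        name = Varchar()
+        facilities = JSONB(null=True)
+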
+===========
+Serialising
+===========
+
+Piccolo automatically converts Python values into JSON strings:
+
+.. code-block:: python
+
+    studio = RecordingStudio(
+        name="Abbey Road",
+        facilities={"restaurant": True, "mixing_desk": True}  # Automatically serialised
+    )
+    await studio.save()
+
+You can also pass in a JSON string if you prefer:
+
+.. code-block:: python
+
+    studio = RecordingStudio(
+        name="Abbey Road",
+        facilities='{"restaurant": true, "mixing_desk": true}'
+    )
+    await studio.save()
+
+=============
+Deserialising
+=============
+
+The contents of a ``JSON`` / ``JSONB`` column are returned as a string by
+default:
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(RecordingStudio.facilities)
+    [{'facilities': '{"restaurant": true, "mixing_desk": true}'}]
+
+However, we can ask Piccolo to deserialise the JSON automatically (see :ref:`load_json`):
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities
+    ... ).output(
+    ...     load_json=True
+    ... )
+    [{'facilities': {'restaurant': True, 'mixing_desk': True}}]
+
+With ``objects`` queries, we can modify the returned JSON, and then save it:
+
+.. code-block:: python
+
+    studio = await RecordingStudio.objects().get(
+        RecordingStudio.name == 'Abbey Road'
+    ).output(load_json=True)
+
+    studio['facilities']['restaurant'] = False
+    await studio.save()
+
+================
+Getting elements
+================
+
+``JSON`` and ``JSONB`` columns have an ``arrow`` method (representing the
+``->`` operator in Postgres), which is useful for retrieving a child element
+from the JSON data.
+
+.. note:: Postgres and CockroachDB only.
+
+``select`` queries
+==================
+
+If we have the following JSON stored in the ``RecordingStudio.facilities``
+column:
+
+.. code-block:: json
+
+    {
+        "instruments": {
+            "drum_kits": 2,
+            "electric_guitars": 10
+        },
+        "restaurant": true,
+        "technicians": [
+            {
+                "name": "Alice Jones"
+            },
+            {
+                "name": "Bob Williams"
+            }
+        ]
+    }
+
+We can retrieve the ``restaurant`` value from the JSON object:
 
-``JSONB`` columns have an ``arrow`` function, which is useful for retrieving
-a subset of the JSON data, and for filtering in a where clause.
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities.arrow('restaurant')
+    ...     .as_alias('restaurant')
+    ... ).output(load_json=True)
+    [{'restaurant': True}, ...]
+
+As a convenience, you can use square brackets, instead of calling ``arrow``
+explicitly:
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities['restaurant']
+    ...     .as_alias('restaurant')
+    ... ).output(load_json=True)
+    [{'restaurant': True}, ...]
+
+You can drill multiple levels deep by calling ``arrow`` multiple times (or
+alternatively use the :ref:`from_path` method - see below).
+
+Here we fetch the number of drum kits that the recording studio has:
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities["instruments"]["drum_kits"]
+    ...     .as_alias("drum_kits")
+    ... ).output(load_json=True)
+    [{'drum_kits': 2}, ...]
+
+If you have a JSON object which consists of arrays and objects, then you can
+navigate the array elements by passing in an integer to ``arrow``.
+
+Here we fetch the first technician from the array:
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities["technicians"][0]["name"]
+    ...     .as_alias("technician_name")
+    ... ).output(load_json=True)
+
+    [{'technician_name': 'Alice Jones'}, ...]
+
+``where`` clauses
+=================
+
+The ``arrow`` operator can also be used for filtering in a where clause:
 
 .. code-block:: python
 
-    # Example schema:
-    class Booking(Table):
-        data = JSONB()
+    >>> await RecordingStudio.select(RecordingStudio.name).where(
+    ...     RecordingStudio.facilities['mixing_desk'].eq(True)
+    ... )
+    [{'name': 'Abbey Road'}]
+
+.. _from_path:
+
+=============
+``from_path``
+=============
+
+This works the same as ``arrow`` but is more optimised if you need to return
+part of a highly nested JSON structure.
+
+.. code-block:: python
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities.from_path([
+    ...         "technicians",
+    ...         0,
+    ...         "name"
+    ...     ]).as_alias("technician_name")
+    ... ).output(load_json=True)
+
+    [{'technician_name': 'Alice Jones'}, ...]
+
+=============
+Handling null
+=============
+
+When assigning a value of ``None`` to a ``JSON`` or ``JSONB`` column, this is
+treated as null in the database.
+
+.. code-block:: python
 
-    Booking.create_table().run_sync()
+    await RecordingStudio(name="ABC Studios", facilities=None).save()
 
-    # Example data:
-    Booking.insert(
-        Booking(data='{"name": "Alison"}'),
-        Booking(data='{"name": "Bob"}')
-    ).run_sync()
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities
+    ... ).where(
+    ...     RecordingStudio.name == "ABC Studios"
+    ... )
+    [{'facilities': None}]
 
-    # Example queries
-    >>> Booking.select(
-    >>>     Booking.id, Booking.data.arrow('name').as_alias('name')
-    >>> ).run_sync()
-    [{'id': 1, 'name': '"Alison"'}, {'id': 2, 'name': '"Bob"'}]
 
-    >>> Booking.select(Booking.id).where(
-    >>>     Booking.data.arrow('name') == '"Alison"'
-    >>> ).run_sync()
-    [{'id': 1}]
+If you want to store JSON null in the database, assign a value of ``'null'``
+instead.
+
+.. code-block:: python
+
+    await RecordingStudio(name="ABC Studios", facilities='null').save()
+
+    >>> await RecordingStudio.select(
+    ...     RecordingStudio.facilities
+    ... ).where(
+    ...     RecordingStudio.name == "ABC Studios"
+    ... )
+    [{'facilities': 'null'}]
 
 -------------------------------------------------------------------------------
 
@@ -211,8 +409,8 @@ a subset of the JSON data, and for filtering in a where clause.
 Array
 *****
 
-Arrays of data can be stored, which can be useful when you want store lots of
-values without using foreign keys.
+Arrays of data can be stored, which can be useful when you want to store lots
+of values without using foreign keys.
 
 .. autoclass:: Array
 
@@ -228,8 +426,45 @@ any
 
 .. automethod:: Array.any
 
+=======
+not_any
+=======
+
+.. automethod:: Array.not_any
+
 ===
 all
 ===
 
 .. automethod:: Array.all
+
+===
+cat
+===
+
+.. automethod:: Array.cat
+
+======
+remove
+======
+
+.. automethod:: Array.remove
+
+======
+append
+======
+
+.. automethod:: Array.append
+
+
+=======
+prepend
+=======
+
+.. automethod:: Array.prepend
+
+=======
+replace
+=======
+
+.. automethod:: Array.replace
diff --git a/docs/src/piccolo/schema/defining.rst b/docs/src/piccolo/schema/defining.rst
index 4da56c8ef..defcaf4b3 100644
--- a/docs/src/piccolo/schema/defining.rst
+++ b/docs/src/piccolo/schema/defining.rst
@@ -3,8 +3,8 @@
 Defining a Schema
 =================
 
-The schema is usually defined within the ``tables.py`` file of your Piccolo
-app (see :ref:`PiccoloApps`).
+The schema is usually defined within the ``tables.py`` file of your
+:ref:`Piccolo app `.
 
 This reflects the tables in your database. Each table consists of several
 columns. Here's a very simple schema:
 
 .. code-block:: python
 
     class Band(Table):
         name = Varchar(length=100)
 
-For a full list of columns, see :ref:`ColumnTypes`.
+For a full list of columns, see :ref:`column types `.
+
+.. hint:: If you're using an existing database, see Piccolo's
+   :ref:`auto schema generation command`, which will save you some
+   time.
 
 -------------------------------------------------------------------------------
 
@@ -27,7 +31,7 @@ Primary Key
 -----------
 
 Piccolo tables are automatically given a primary key column called ``id``,
-which is an auto incrementing integer.
+which is an auto incrementing integer (a ``Serial(primary_key=True)`` column).
 
 There is currently experimental support for specifying a custom primary key
 column. For example:
diff --git a/docs/src/piccolo/schema/images/m2m.png b/docs/src/piccolo/schema/images/m2m.png
new file mode 100644
index 000000000..4831569ba
Binary files /dev/null and b/docs/src/piccolo/schema/images/m2m.png differ
diff --git a/docs/src/piccolo/schema/index.rst b/docs/src/piccolo/schema/index.rst
index e1450745f..ec9b887e6 100644
--- a/docs/src/piccolo/schema/index.rst
+++ b/docs/src/piccolo/schema/index.rst
@@ -8,4 +8,6 @@ The schema is how you define your database tables, columns and relationships.
 
     ./defining
     ./column_types
+    ./m2m
+    ./one_to_one
     ./advanced
diff --git a/docs/src/piccolo/schema/m2m.rst b/docs/src/piccolo/schema/m2m.rst
new file mode 100644
index 000000000..b7c44188a
--- /dev/null
+++ b/docs/src/piccolo/schema/m2m.rst
@@ -0,0 +1,165 @@
+.. currentmodule:: piccolo.columns.m2m
+
+###
+M2M
+###
+
+.. note::
+
+    There is a `video tutorial on YouTube `_.
+
+Sometimes in database design you need `many-to-many (M2M) `_
+relationships.
+
+For example, we might have our ``Band`` table, and want to describe which genres of music
+each band belongs to (e.g. rock and electronic). As each band can have multiple genres, a ``ForeignKey``
+on the ``Band`` table won't suffice. Our options are using an ``Array`` / ``JSON`` / ``JSONB``
+column, or using an ``M2M`` relationship.
+
+Postgres and SQLite don't natively support ``M2M`` relationships - we create
+them using a joining table which has foreign keys to each of the related tables
+(in our example, ``Genre`` and ``Band``).
+
+.. image:: ./images/m2m.png
+    :width: 500
+    :align: center
+
+We create it in Piccolo like this:
+
+.. code-block:: python
+
+    from piccolo.columns.column_types import (
+        ForeignKey,
+        LazyTableReference,
+        Varchar
+    )
+    from piccolo.columns.m2m import M2M
+    from piccolo.table import Table
+
+
+    class Band(Table):
+        name = Varchar()
+        genres = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+    class Genre(Table):
+        name = Varchar()
+        bands = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+    # This is our joining table:
+    class GenreToBand(Table):
+        band = ForeignKey(Band)
+        genre = ForeignKey(Genre)
+
+
+.. note::
+    We use :class:`LazyTableReference `
+    because when Python evaluates ``Band`` and ``Genre``, the ``GenreToBand``
+    class doesn't exist yet.
+
+Using ``M2M`` unlocks some powerful and convenient features.
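+
+For the query examples below, assume some data like this has been added (a
+sketch; the primary key values are assumptions):
+
+.. code-block:: python
+
+    await Band.insert(Band(name="Pythonistas"), Band(name="Rustaceans"))
+    await Genre.insert(Genre(name="Rock"), Genre(name="Folk"))
+
+    # Link bands to genres via the joining table:
+    await GenreToBand.insert(
+        GenreToBand(band=1, genre=1),
+        GenreToBand(band=1, genre=2),
+        GenreToBand(band=2, genre=2),
+    )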
+
+-------------------------------------------------------------------------------
+
+Select queries
+==============
+
+If we want to select each band, along with a list of genres that they belong to,
+we can do this:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.genres(Genre.name, as_list=True))
+    [
+        {"name": "Pythonistas", "genres": ["Rock", "Folk"]},
+        {"name": "Rustaceans", "genres": ["Folk"]},
+        {"name": "C-Sharps", "genres": ["Rock", "Classical"]},
+    ]
+
+You can request whichever column you like from the related table:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.genres(Genre.id, as_list=True))
+    [
+        {"name": "Pythonistas", "genres": [1, 2]},
+        {"name": "Rustaceans", "genres": [2]},
+        {"name": "C-Sharps", "genres": [1, 3]},
+    ]
+
+You can also request multiple columns from the related table:
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.genres(Genre.id, Genre.name))
+    [
+        {
+            'name': 'Pythonistas',
+            'genres': [
+                {'id': 1, 'name': 'Rock'},
+                {'id': 2, 'name': 'Folk'}
+            ]
+        },
+        ...
+    ]
+
+If you omit the columns argument, then all of the columns are returned.
+
+.. code-block:: python
+
+    >>> await Band.select(Band.name, Band.genres())
+    [
+        {
+            'name': 'Pythonistas',
+            'genres': [
+                {'id': 1, 'name': 'Rock'},
+                {'id': 2, 'name': 'Folk'}
+            ]
+        },
+        ...
+    ]
+
+
+As we defined ``M2M`` on the ``Genre`` table too, we can get each band in a
+given genre:
+
+.. code-block:: python
+
+    >>> await Genre.select(Genre.name, Genre.bands(Band.name, as_list=True))
+    [
+        {"name": "Rock", "bands": ["Pythonistas", "C-Sharps"]},
+        {"name": "Folk", "bands": ["Pythonistas", "Rustaceans"]},
+        {"name": "Classical", "bands": ["C-Sharps"]},
+    ]
+
+-------------------------------------------------------------------------------
+
+Objects queries
+===============
+
+Piccolo makes it easy to work with objects and ``M2M`` relationships.
+
+
+add_m2m
+-------
+
+.. currentmodule:: piccolo.table
+
+.. automethod:: Table.add_m2m
+    :noindex:
+
+get_m2m
+-------
+
+.. automethod:: Table.get_m2m
+    :noindex:
+
+remove_m2m
+----------
+
+.. automethod:: Table.remove_m2m
+    :noindex:
+
+.. hint:: All of these methods can be run synchronously as well - for example,
+   ``band.get_m2m(Band.genres).run_sync()``.
diff --git a/docs/src/piccolo/schema/one_to_one.rst b/docs/src/piccolo/schema/one_to_one.rst
new file mode 100644
index 000000000..f35433189
--- /dev/null
+++ b/docs/src/piccolo/schema/one_to_one.rst
@@ -0,0 +1,85 @@
+.. _OneToOne:
+
+One to One
+==========
+
+Schema
+------
+
+A one to one relationship is basically just a foreign key with a unique
+constraint. In Piccolo, you can do it like this:
+
+.. code-block:: python
+
+    from piccolo.table import Table
+    from piccolo.columns import ForeignKey, Varchar, Text
+
+    class Band(Table):
+        name = Varchar()
+
+    class FanClub(Table):
+        band = ForeignKey(Band, unique=True)  # <- Note the unique constraint
+        address = Text()
+
+Queries
+-------
+
+Getting a related object
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+If we have a ``Band`` object:
+
+.. code-block:: python
+
+    band = await Band.objects().where(Band.name == "Pythonistas").first()
+
+To get the associated ``FanClub`` object, you could do this:
+
+.. code-block:: python
+
+    fan_club = await FanClub.objects().where(FanClub.band == band).first()
+
+Or alternatively, using ``get_related``:
+
+.. code-block:: python
+
+    fan_club = await band.get_related(Band.id.join_on(FanClub.band))
+
+Instead of using ``join_on``, you can use ``reverse`` to traverse the foreign
+key backwards if you prefer:
+
+.. code-block:: python
+
+    fan_club = await band.get_related(FanClub.band.reverse())
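+
+The ``FanClub`` row used in these examples could have been created like this
+(a sketch; the address is made up):
+
+.. code-block:: python
+
+    await FanClub(
+        band=band.id, address="1 Flying Circus, UK"
+    ).save()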
+
+Select
+~~~~~~
+
+If doing a select query, and you want data from the related table:
+
+.. code-block:: python
+
+    >>> await Band.select(
+    ...     Band.name,
+    ...     Band.id.join_on(FanClub.band).address.as_alias("address")
+    ... )
+    [{'name': 'Pythonistas', 'address': '1 Flying Circus, UK'}, ...]
+
+Where
+~~~~~
+
+If you want to filter by related tables in the ``where`` clause:
+
+.. code-block:: python
+
+    >>> await Band.select(
+    ...     Band.name,
+    ... ).where(Band.id.join_on(FanClub.band).address.like("%Flying%"))
+    [{'name': 'Pythonistas'}]
+
+Source
+------
+
+.. currentmodule:: piccolo.columns.column_types
+
+.. automethod:: ForeignKey.reverse
diff --git a/docs/src/piccolo/serialization/index.rst b/docs/src/piccolo/serialization/index.rst
new file mode 100644
index 000000000..446e6f20f
--- /dev/null
+++ b/docs/src/piccolo/serialization/index.rst
@@ -0,0 +1,302 @@
+Serialization
+=============
+
+Piccolo uses `Pydantic `_ internally
+to serialize and deserialize data.
+
+Using ``create_pydantic_model`` you can easily create Pydantic models for your
+application.
+
+-------------------------------------------------------------------------------
+
+``create_pydantic_model``
+-------------------------
+
+Using ``create_pydantic_model`` we can easily create a `Pydantic model `_
+from a Piccolo ``Table``.
+
+Using this example schema:
+
+.. code-block:: python
+
+    from piccolo.columns import ForeignKey, Integer, Varchar
+    from piccolo.table import Table
+
+    class Manager(Table):
+        name = Varchar()
+
+    class Band(Table):
+        name = Varchar(length=100)
+        manager = ForeignKey(Manager)
+        popularity = Integer()
+
+Creating a Pydantic model is as simple as:
+
+.. code-block:: python
+
+    from piccolo.utils.pydantic import create_pydantic_model
+
+    BandModel = create_pydantic_model(Band)
+
+We can then create model instances from data we fetch from the database:
+
+.. code-block:: python
+
+    # If using objects:
+    band = await Band.objects().get(Band.name == 'Pythonistas')
+    model = BandModel(**band.to_dict())
+
+    # If using select:
+    band = await Band.select().where(Band.name == 'Pythonistas').first()
+    model = BandModel(**band)
+
+    >>> model.name
+    'Pythonistas'
+
+You have several options for configuring the model, as shown below.
+
+``include_columns`` / ``exclude_columns``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If we want to exclude the ``popularity`` column from the ``Band`` table:
+
+.. code-block:: python
+
+    BandModel = create_pydantic_model(Band, exclude_columns=(Band.popularity,))
+
+Conversely, if you only wanted the ``popularity`` column:
+
+.. code-block:: python
+
+    BandModel = create_pydantic_model(Band, include_columns=(Band.popularity,))
+
+``nested``
+~~~~~~~~~~
+
+Another great feature is ``nested=True``. For each ``ForeignKey`` in the
+Piccolo ``Table``, the Pydantic model will contain a sub model for the related
+table.
+
+For example:
+
+.. code-block:: python
+
+    BandModel = create_pydantic_model(Band, nested=True)
+
+If we were to write ``BandModel`` by hand instead, it would look like this:
+
+.. code-block:: python
+
+    from pydantic import BaseModel
+
+    class ManagerModel(BaseModel):
+        name: str
+
+    class BandModel(BaseModel):
+        name: str
+        manager: ManagerModel
+        popularity: int
+
+But with ``nested=True`` we can achieve this with one line of code.
+
+To populate a nested Pydantic model with data from the database:
+
+.. code-block:: python
+
+    # If using objects:
+    band = await Band.objects(Band.manager).get(Band.name == 'Pythonistas')
+    model = BandModel(**band.to_dict())
+
+    # If using select:
+    band = await Band.select(
+        Band.all_columns(),
+        Band.manager.all_columns()
+    ).where(
+        Band.name == 'Pythonistas'
+    ).first().output(
+        nested=True
+    )
+    model = BandModel(**band)
+
+    >>> model.manager.name
+    'Guido'
+
+.. note::
+
+    There is a `video tutorial on YouTube `_.
+
+``include_default_columns``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Sometimes you'll want to include the Piccolo ``Table``'s primary key column in
+the generated Pydantic model. For example, in a ``GET`` endpoint, we usually
+want to include the ``id`` in the response:
+
+.. code-block:: javascript
+
+    // GET /api/bands/1/
+    // Response:
+    {"id": 1, "name": "Pythonistas", "popularity": 1000}
+
+Other times, you won't want the Pydantic model to include the primary key
+column. For example, in a ``POST`` endpoint, when using a Pydantic model to
+serialise the payload, we don't expect the user to pass in an ``id`` value:
+
+.. code-block:: javascript
+
+    // POST /api/bands/
+    // Payload:
+    {"name": "Pythonistas", "popularity": 1000}
+
+By default the primary key column isn't included - you can add it using:
+
+.. code-block:: python
+
+    BandModel = create_pydantic_model(Band, include_default_columns=True)
+
+``pydantic_config``
+~~~~~~~~~~~~~~~~~~~
+
+.. hint:: We used to have a ``pydantic_config_class`` argument in Piccolo prior
+   to v1, but it has been replaced with ``pydantic_config`` due to changes in
+   Pydantic v2.
+
+You can specify a Pydantic ``ConfigDict`` to use as the base for the Pydantic
+model's config (see `docs `_).
+
+For example, let's set the ``extra`` parameter to tell Pydantic how to treat
+extra fields (that is, fields that would not otherwise be in the generated model).
+The allowed values are:
+
+* ``'ignore'`` (default): silently ignore extra fields
+* ``'allow'``: accept the extra fields and assign them to the model
+* ``'forbid'``: fail validation if extra fields are present
+
+So if we want to disallow extra fields, we can do:
+
+.. code-block:: python
+
+    from pydantic.config import ConfigDict
+
+    config: ConfigDict = {
+        "extra": "forbid"
+    }
+
+    model = create_pydantic_model(
+        table=MyTable,
+        pydantic_config=config
+    )
+
+
+Required fields
+~~~~~~~~~~~~~~~
+
+You can specify which fields are required using the ``required``
+argument of :class:`Column `. For example:
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar(required=True)
+
+    BandModel = create_pydantic_model(Band)
+
+    # Omitting the field raises an error:
+    >>> BandModel()
+    ValidationError - name field required
+
+You can override this behaviour using the ``all_optional`` argument. An example
+use case is a model which is used for filtering, where you'll want all fields
+to be optional.
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar(required=True)
+
+    BandFilterModel = create_pydantic_model(
+        Band,
+        all_optional=True,
+        model_name='BandFilterModel',
+    )
+
+    # This no longer raises an exception:
+    >>> BandFilterModel()
+
+Subclassing the model
+~~~~~~~~~~~~~~~~~~~~~
+
+If the generated model doesn't perfectly fit your needs, you can subclass it to
+add additional fields, and to override existing fields.
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar(required=True)
+
+    BandModel = create_pydantic_model(Band)
+
+    class CustomBandModel(BandModel):
+        genre: str
+
+    >>> CustomBandModel(name="Pythonistas", genre="Rock")
+
+Or even simpler still:
+
+.. code-block:: python
+
+    class BandModel(create_pydantic_model(Band)):
+        genre: str
+
+
+Avoiding type warnings
+~~~~~~~~~~~~~~~~~~~~~~
+
+Some linters will complain if you use variables in type annotations:
+
+.. code-block:: python
+
+    BandModel = create_pydantic_model(Band)
+
+
+    def my_function(band: BandModel):  # Variable not allowed in type expression!
+        ...
+
+
+The fix is really simple:
+
+.. code-block:: python
+
+    # We now have a class instead of a variable:
+    class BandModel(create_pydantic_model(Band)):
+        ...
+
+
+    def my_function(band: BandModel):
+        ...
+
+Source
+~~~~~~
+
+.. currentmodule:: piccolo.utils.pydantic
+
+.. autofunction:: create_pydantic_model
+
+.. hint:: A good place to see ``create_pydantic_model`` in action is `PiccoloCRUD `_,
+   as it uses ``create_pydantic_model`` extensively to create Pydantic models
+   from Piccolo tables.
+
+-------------------------------------------------------------------------------
+
+FastAPI template
+----------------
+
+Piccolo's FastAPI template uses ``create_pydantic_model`` to create serializers.
+
+To create a new FastAPI app using Piccolo, simply use:
+
+.. code-block:: bash
+
+    piccolo asgi new
+
+See the :ref:`ASGI docs ` for more details.
diff --git a/docs/src/piccolo/testing/index.rst b/docs/src/piccolo/testing/index.rst
new file mode 100644
index 000000000..84eabff83
--- /dev/null
+++ b/docs/src/piccolo/testing/index.rst
@@ -0,0 +1,243 @@
+Testing
+=======
+
+Piccolo provides a few tools to make testing easier.
+
+-------------------------------------------------------------------------------
+
+Test runner
+-----------
+
+Piccolo ships with a handy command for running your unit tests using pytest.
+See the :ref:`tester app`.
+
+You can put your test files anywhere you like, but a good place is in a ``tests``
+folder within your Piccolo app. The test files should be named like
+``test_*.py`` or ``*_test.py`` for pytest to recognise them.
+
+-------------------------------------------------------------------------------
+
+Model Builder
+-------------
+
+When writing unit tests, it's usually required to have some data seeded into
+the database. You can build and save the records manually or use
+:class:`ModelBuilder ` to generate
+random records for you.
+
+This way you can randomize the fields you don't care about, specify the
+important fields explicitly, and reduce the amount of manual work required.
+``ModelBuilder`` currently supports all Piccolo column types and features.
+
+Let's say we have the following schema:
+
+.. code-block:: python
+
+    from piccolo.columns import ForeignKey, Varchar
+
+    class Manager(Table):
+        name = Varchar(length=50)
+
+    class Band(Table):
+        name = Varchar(length=50)
+        manager = ForeignKey(Manager, null=True)
+
+You can build a random ``Band`` which will also build and save a random
+``Manager``:
+
+.. code-block:: python
+
+    from piccolo.testing.model_builder import ModelBuilder
+
+    # Band instance with random values persisted:
+    band = await ModelBuilder.build(Band)
+
+.. note:: ``ModelBuilder.build(Band)`` persists the record into the database by default.
+
+You can also run it synchronously if you prefer:
+
+.. code-block:: python
+
+    manager = ModelBuilder.build_sync(Manager)
+
+
+To specify any attribute, pass the ``defaults`` dictionary to the ``build`` method:
+
+.. code-block:: python
+
+    manager = await ModelBuilder.build(Manager)
+
+    # Using table columns:
+    band = await ModelBuilder.build(
+        Band,
+        defaults={Band.name: "Guido", Band.manager: manager}
+    )
+
+    # Or using strings as keys:
+    band = await ModelBuilder.build(
+        Band,
+        defaults={"name": "Guido", "manager": manager}
+    )
+
+To build objects without persisting them into the database:
+
+.. code-block:: python
+
+    band = await ModelBuilder.build(Band, persist=False)
+
+To build objects with minimal attributes, leaving nullable fields empty:
+
+.. code-block:: python
+
+    # Leaves manager empty:
+    band = await ModelBuilder.build(Band, minimal=True)
+
+-------------------------------------------------------------------------------
+
+Creating the test schema
+------------------------
+
+When running your unit tests, you usually start with a blank test database,
+create the tables, and then install test data.
+
+To create the tables, there are a few different approaches you can take.
+
+``create_db_tables`` / ``drop_db_tables``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Here we use :func:`create_db_tables ` and
+:func:`drop_db_tables ` to create and drop the
+tables.
+
+.. note::
+    The sync equivalents are :func:`create_db_tables_sync `
+    and :func:`drop_db_tables_sync `, if
+    you need your tests to be synchronous for some reason.
+
+.. code-block:: python
+
+    from unittest import IsolatedAsyncioTestCase
+
+    from piccolo.table import create_db_tables, drop_db_tables
+    from piccolo.conf.apps import Finder
+
+
+    TABLES = Finder().get_table_classes()
+
+
+    class TestApp(IsolatedAsyncioTestCase):
+        async def asyncSetUp(self):
+            await create_db_tables(*TABLES)
+
+        async def asyncTearDown(self):
+            await drop_db_tables(*TABLES)
+
+        async def test_app(self):
+            # Do some testing ...
+            pass
+
+You can remove this boilerplate by using
+:class:`AsyncTransactionTest `,
+which does this for you.
+
+Run migrations
+~~~~~~~~~~~~~~
+
+Alternatively, you can run the migrations to set up the schema if you prefer:
+
+.. code-block:: python
+
+    from unittest import IsolatedAsyncioTestCase
+
+    from piccolo.apps.migrations.commands.backwards import run_backwards
+    from piccolo.apps.migrations.commands.forwards import run_forwards
+
+
+    class TestApp(IsolatedAsyncioTestCase):
+        async def asyncSetUp(self):
+            await run_forwards("all")
+
+        async def asyncTearDown(self):
+            await run_backwards("all", auto_agree=True)
+
+        async def test_app(self):
+            # Do some testing ...
+            pass
+
+-------------------------------------------------------------------------------
+
+Testing async code
+------------------
+
+There are a few options for testing async code using pytest.
+
+``run_sync``
+~~~~~~~~~~~~
+
+You can call any async code using Piccolo's ``run_sync`` utility:
+
+.. code-block:: python
+
+    from piccolo.utils.sync import run_sync
+
+    async def get_data():
+        ...
+
+    def test_get_data():
+        rows = run_sync(get_data())
+        assert len(rows) == 1
+
+It's preferable to make your tests natively async though.
+
+``pytest-asyncio``
+~~~~~~~~~~~~~~~~~~
+
+If you prefer using pytest's function based tests, then take a look at
+`pytest-asyncio `_. Simply
+install it using ``pip install pytest-asyncio``, and you can then write tests
+like this:
+
+.. code-block:: python
+
+    async def test_select():
+        rows = await MyTable.select()
+        assert len(rows) == 1
+
+``IsolatedAsyncioTestCase``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you prefer class based tests, and are using Python 3.8 or above, then have
+a look at :class:`IsolatedAsyncioTestCase `
+from Python's standard library. You can then write tests like this:
+
+.. code-block:: python
+
+    from unittest import IsolatedAsyncioTestCase
+
+    class MyTest(IsolatedAsyncioTestCase):
+        async def test_select(self):
+            rows = await MyTable.select()
+            assert len(rows) == 1
+
+Also look at the ``IsolatedAsyncioTestCase`` subclasses which Piccolo provides
+(see :class:`AsyncTransactionTest `
+and :class:`AsyncTableTest ` below).
+
+-------------------------------------------------------------------------------
+
+``TestCase`` subclasses
+-----------------------
+
+Piccolo ships with some ``unittest.TestCase`` subclasses which remove
+boilerplate code from tests.
+
+.. currentmodule:: piccolo.testing.test_case
+
+.. autoclass:: AsyncTransactionTest
+    :class-doc-from: class
+
+.. autoclass:: AsyncTableTest
+    :class-doc-from: class
+
+.. autoclass:: TableTest
+    :class-doc-from: class
diff --git a/docs/src/piccolo/tutorials/avoiding_circular_imports.rst b/docs/src/piccolo/tutorials/avoiding_circular_imports.rst
new file mode 100644
index 000000000..fd587abb6
--- /dev/null
+++ b/docs/src/piccolo/tutorials/avoiding_circular_imports.rst
@@ -0,0 +1,82 @@
+Avoiding circular imports
+=========================
+
+How Python imports work
+-----------------------
+
+When Python imports a file, it evaluates it from top to bottom.
+
+With :class:`ForeignKey ` columns we
+sometimes have to reference tables lower down in the file (which haven't been
+evaluated yet).
+
+The solutions are:
+
+* Try and move the referenced table to a different Python file.
+* Use :class:`LazyTableReference `
+
+Import ``Table`` definitions as early as possible
+-------------------------------------------------
+
+In the entrypoint to your app, at the top of the file, it's recommended to
+import your tables.
+
+.. code-block:: python
+
+    # main.py
+    from my_app.tables import Manager, Band
+
+This ensures that the tables are imported, and set up correctly.
+
+Keep table files focused
+------------------------
+
+You should try and keep your ``tables.py`` files pretty focused (i.e. they
+should just contain your ``Table`` definitions).
+
+If you have lots of logic alongside your ``Table`` definitions, it might cause
+your ``LazyTableReference`` references to evaluate too soon (causing circular
+import errors). An example of this is with
+:func:`create_pydantic_model `:
+
+.. literalinclude:: avoiding_circular_imports_src/tables.py
+
+Simplify your schema if possible
+--------------------------------
+
+Even with :class:`LazyTableReference `,
+you may run into some problems if your schema is really complicated.
+
+An example is when you have two tables, and they have foreign keys to each other.
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar()
+        manager = ForeignKey("Manager")
+
+
+    class Manager(Table):
+        name = Varchar()
+        favourite_band = ForeignKey(Band)
+
+
+Piccolo should be able to create these tables, and query them. However, some
+Piccolo tooling may struggle - for example when loading :ref:`fixtures `.
+
+A joining table can help in these situations:
+
+.. code-block:: python
+
+    class Band(Table):
+        name = Varchar()
+        manager = ForeignKey("Manager")
+
+
+    class Manager(Table):
+        name = Varchar()
+
+
+    class ManagerFavouriteBand(Table):
+        manager = ForeignKey(Manager, unique=True)
+        band = ForeignKey(Band)
diff --git a/docs/src/piccolo/tutorials/avoiding_circular_imports_src/tables.py b/docs/src/piccolo/tutorials/avoiding_circular_imports_src/tables.py
new file mode 100644
index 000000000..6d1021deb
--- /dev/null
+++ b/docs/src/piccolo/tutorials/avoiding_circular_imports_src/tables.py
@@ -0,0 +1,22 @@
+# tables.py
+
+from piccolo.columns import ForeignKey, Varchar
+from piccolo.table import Table
+from piccolo.utils.pydantic import create_pydantic_model
+
+
+class Band(Table):
+    name = Varchar()
+    # This automatically gets converted into a LazyTableReference, because a
+    # string is passed in:
+    manager = ForeignKey("Manager")
+
+
+# This is not recommended, as it will cause the LazyTableReference to be
+# evaluated before Manager has been imported.
+# Instead, move this to a separate file, or below Manager.
+BandModel = create_pydantic_model(Band)
+
+
+class Manager(Table):
+    name = Varchar()
diff --git a/docs/src/piccolo/tutorials/deployment.rst b/docs/src/piccolo/tutorials/deployment.rst
new file mode 100644
index 000000000..3b5352d58
--- /dev/null
+++ b/docs/src/piccolo/tutorials/deployment.rst
@@ -0,0 +1,87 @@
+Deploying using Docker
+======================
+
+Docker
+------
+
+`Docker `_ is a very popular way of deploying
+applications, using containers.
+
+Base image
+~~~~~~~~~~
+
+Piccolo has several dependencies which are compiled (e.g. asyncpg, orjson),
+which is great for performance, but you may run into difficulties when using
+Alpine Linux as your base Docker image. Alpine uses a different compiler
+toolchain to most Linux distros.
+
+It's highly recommended to use Debian as your base Docker image. Many Python packages
+have prebuilt versions for Debian, meaning you don't have to compile them at
+all during install. The result is a much faster build process, and potentially
+even a smaller overall Docker image size (the size of Alpine quickly balloons
+after you've added all of the compilation dependencies).
+
+Environment variables
+~~~~~~~~~~~~~~~~~~~~~
+
+By using environment variables, we can inject the database credentials for
+Piccolo.
+
+Example Dockerfile
+~~~~~~~~~~~~~~~~~~
+
+This is a very simple Dockerfile, and illustrates the basics:
+
+.. code-block:: dockerfile
+
+    # Specify the base image:
+    FROM python:3.12-bookworm
+
+    # Install the pip requirements:
+    RUN pip install --upgrade pip
+    ADD app/requirements.txt /
+    RUN pip install -r /requirements.txt
+
+    # Add the application code:
+    ADD app /app
+
+    # Environment variables:
+    ENV PG_HOST=localhost
+    ENV PG_PORT=5432
+    ENV PG_USER=my_database_user
+    ENV PG_PASSWORD=""
+    ENV PG_DATABASE=my_database
+
+    CMD ["/usr/local/bin/python", "/app/main.py"]
+
+We can then modify our :ref:`piccolo_conf.py ` file to use these
+environment variables:
+
+.. code-block:: python
+
+    # piccolo_conf.py
+
+    import os
+
+    DB = PostgresEngine(
+        config={
+            "port": int(os.environ.get("PG_PORT", "5432")),
+            "user": os.environ.get("PG_USER", "my_database_user"),
+            "password": os.environ.get("PG_PASSWORD", ""),
+            "database": os.environ.get("PG_DATABASE", "my_database"),
+            "host": os.environ.get("PG_HOST", "localhost"),
+        }
+    )
+
+When we run the container (usually via `Kubernetes `_,
+`Docker Compose `_, or similar),
+we can specify the database credentials using environment variables, which will
+be used by our application.
+
+Accessing a local Postgres database
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Bear in mind that if you have Postgres running locally on the server (i.e. on
+``localhost``), your Docker container won't automatically be able to access it.
+You can try Docker's host based networking, or just run Postgres within a
+Docker container.
diff --git a/docs/src/piccolo/tutorials/fastapi.rst b/docs/src/piccolo/tutorials/fastapi.rst
new file mode 100644
index 000000000..0e0c91a61
--- /dev/null
+++ b/docs/src/piccolo/tutorials/fastapi.rst
@@ -0,0 +1,52 @@
+FastAPI
+=======
+
+`FastAPI `_ is a popular ASGI web framework. The
+purpose of this tutorial is to give some hints on how to get started with
+Piccolo and FastAPI.
+
+Piccolo and FastAPI are a great match, and are commonly used together.
+
+Creating a new project
+----------------------
+
+Using the ``piccolo asgi new`` command, Piccolo will scaffold a new FastAPI app
+for you - simple!
+
+Pydantic models
+---------------
+
+FastAPI uses `Pydantic `_ for serialising and
+deserialising data.
+
+Piccolo provides :func:`create_pydantic_model `
+which creates Pydantic models for you based on your Piccolo tables.
+
+Of course, you can also just define your Pydantic models by hand.
+
+Transactions
+------------
+
+Using FastAPI's dependency injection system, we can easily wrap each endpoint
+in a transaction.
+
+.. literalinclude:: fastapi_src/app.py
+    :emphasize-lines: 19-21,36
+
+FastAPI dependencies can be declared at the endpoint, ``APIRouter``, or even
+app level.
+
+``FastAPIWrapper``
+------------------
+
+Piccolo API has a powerful utility called
+:class:`FastAPIWrapper ` which
+generates REST endpoints based on your Piccolo tables, and adds them to FastAPI's
+Swagger docs. It's a very productive way of building an API.
+
+Authentication
+--------------
+
+`Piccolo API `_ ships with
+`authentication middleware `_
+which is compatible with `FastAPI middleware `_.
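+
+As a rough sketch of how ``FastAPIWrapper`` can be wired up (assuming a
+``Task`` table, like the one which ``piccolo asgi new`` generates):
+
+.. code-block:: python
+
+    from fastapi import FastAPI
+    from piccolo_api.crud.endpoints import PiccoloCRUD
+    from piccolo_api.fastapi.endpoints import FastAPIWrapper
+
+    from home.tables import Task
+
+    app = FastAPI()
+
+    # Adds CRUD endpoints under /tasks/, and documents them in the
+    # Swagger UI:
+    FastAPIWrapper(
+        root_url="/tasks/",
+        fastapi_app=app,
+        piccolo_crud=PiccoloCRUD(table=Task),
+    )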
diff --git a/docs/src/piccolo/tutorials/fastapi_src/app.py b/docs/src/piccolo/tutorials/fastapi_src/app.py
new file mode 100644
index 000000000..56525287f
--- /dev/null
+++ b/docs/src/piccolo/tutorials/fastapi_src/app.py
@@ -0,0 +1,54 @@
+from fastapi import Depends, FastAPI
+from pydantic import BaseModel
+
+from piccolo.columns.column_types import Varchar
+from piccolo.engine.sqlite import SQLiteEngine
+from piccolo.table import Table
+
+DB = SQLiteEngine()
+
+
+class Band(Table, db=DB):
+    """
+    You would usually import this from tables.py
+    """
+
+    name = Varchar()
+
+
+async def transaction():
+    async with DB.transaction() as transaction:
+        yield transaction
+
+
+app = FastAPI()
+
+
+@app.get("/bands/", dependencies=[Depends(transaction)])
+async def get_bands():
+    return await Band.select()
+
+
+class CreateBandModel(BaseModel):
+    name: str
+
+
+@app.post("/bands/", dependencies=[Depends(transaction)])
+async def create_band(model: CreateBandModel):
+    await Band({Band.name: model.name}).save()
+
+    # If an exception is raised then the transaction is rolled back.
+    raise Exception("Oops")
+
+
+async def main():
+    await Band.create_table(if_not_exists=True)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    import uvicorn
+
+    asyncio.run(main())
+    uvicorn.run(app)
diff --git a/docs/src/piccolo/tutorials/index.rst b/docs/src/piccolo/tutorials/index.rst
new file mode 100644
index 000000000..e017ad7d5
--- /dev/null
+++ b/docs/src/piccolo/tutorials/index.rst
@@ -0,0 +1,15 @@
+Tutorials
+=========
+
+These tutorials bring together information from across the documentation, to
+help you solve common problems:
+
+.. toctree::
+    :maxdepth: 1
+
+    ./migrate_existing_project
+    ./using_sqlite_and_asyncio_effectively
+    ./deployment
+    ./fastapi
+    ./avoiding_circular_imports
+    ./moving_table_between_apps
diff --git a/docs/src/piccolo/tutorials/migrate_existing_project.rst b/docs/src/piccolo/tutorials/migrate_existing_project.rst
new file mode 100644
index 000000000..e4810d25a
--- /dev/null
+++ b/docs/src/piccolo/tutorials/migrate_existing_project.rst
@@ -0,0 +1,129 @@
+Migrate an existing project to Piccolo
+======================================
+
+Introduction
+------------
+
+If you have an existing project and Postgres database, and you want to use
+Piccolo with it, these are the steps you need to take.
+
+Option 1 - ``piccolo asgi new``
+-------------------------------
+
+This is the recommended way of creating brand new projects. If this is your
+first experience with Piccolo, then it's a good idea to create a test project:
+
+.. code-block:: bash
+
+    mkdir test_project
+    cd test_project
+    piccolo asgi new
+
+You'll learn a lot about how Piccolo works by looking at the generated code.
+You can then copy over the relevant files to your existing project if you like.
+
+Alternatively, to do it from scratch, you'll need to do the following:
+
+Option 2 - from scratch
+-----------------------
+
+Create a Piccolo project file
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Create a new ``piccolo_conf.py`` file in the root of your project:
+
+.. code-block:: bash
+
+    piccolo project new
+
+This contains your database details, and is used to register Piccolo apps.
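+
+The generated file looks roughly like this (a sketch; the exact contents vary
+between Piccolo versions):
+
+.. code-block:: python
+
+    # piccolo_conf.py
+    from piccolo.conf.apps import AppRegistry
+    from piccolo.engine.postgres import PostgresEngine
+
+    DB = PostgresEngine(
+        config={
+            "database": "my_database",
+            "user": "postgres",
+            "password": "",
+            "host": "localhost",
+            "port": 5432,
+        }
+    )
+
+    # Register your Piccolo apps here:
+    APP_REGISTRY = AppRegistry(apps=[])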
+
+Create a new Piccolo app
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+The app contains your ``Table`` classes and migrations. Run this command at the
+root of your project:
+
+.. code-block:: bash
+
+    # Replace 'my_app' with whatever you want to call your app
+    piccolo app new my_app
+
+Register the new Piccolo app
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Register this new app in ``piccolo_conf.py``. For example:
+
+.. code-block:: python
+
+    APP_REGISTRY = AppRegistry(
+        apps=[
+            "my_app.piccolo_app",
+        ]
+    )
+
+While you're at it, make sure the database credentials are correct in
+``piccolo_conf.py``.
+
+Make ``Table`` classes for your current database
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Now, if you run:
+
+.. code-block:: bash
+
+    piccolo schema generate
+
+It will output Piccolo ``Table`` classes for your current database. Copy the
+output into ``my_app/tables.py``. Double check that everything looks correct.
+
+In ``my_app/piccolo_app.py`` make sure it's tracking these tables for
+migration purposes.
+
+.. code-block:: python
+
+    from piccolo.conf.apps import AppConfig, table_finder
+
+    APP_CONFIG = AppConfig(
+        table_classes=table_finder(["my_app.tables"], exclude_imported=True),
+        ...
+    )
+
+Create an initial migration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This will create a new file in ``my_app/piccolo_migrations``:
+
+.. code-block:: bash
+
+    piccolo migrations new my_app --auto
+
+These tables already exist in the database, as it's an existing project, so
+you need to :ref:`fake apply ` this initial migration:
+
+.. code-block:: bash
+
+    piccolo migrations forwards my_app --fake
+
+Making queries
+~~~~~~~~~~~~~~
+
+Now you're basically set up. To make database queries:
+
+.. code-block:: python
+
+    from my_app.tables import MyTable
+
+    async def my_endpoint():
+        data = await MyTable.select()
+        return data
+
+Making new migrations
+~~~~~~~~~~~~~~~~~~~~~
+
+Just modify the files in ``tables.py``, and then run:
+
+.. code-block:: bash
+
+    piccolo migrations new my_app --auto
+    piccolo migrations forwards my_app
diff --git a/docs/src/piccolo/tutorials/moving_table_between_apps.rst b/docs/src/piccolo/tutorials/moving_table_between_apps.rst
new file mode 100644
index 000000000..5d72c52f5
--- /dev/null
+++ b/docs/src/piccolo/tutorials/moving_table_between_apps.rst
@@ -0,0 +1,82 @@
+Moving a Table Between Piccolo Apps Without Data Loss
+======================================================
+
+Piccolo ORM makes it easy to manage models within individual apps. But what if
+you need to move a table (model) from one app to another, say from ``app_a``
+to ``app_b``, without losing your data?
+
+This tutorial walks you through the safest way to move a table between Piccolo apps using migrations and the ``--fake`` flag.
+
+Use Case
+--------
+
+You're working on a project structured with multiple Piccolo apps, and you want to reorganize your models by moving a table (``TableA``) from one app (``app_a``) to another (``app_b``), without affecting the data in your database.
+
+Prerequisites
+-------------
+
+- Piccolo ORM installed and configured
+- Both ``app_a`` and ``app_b`` registered in ``piccolo_conf.py`` under ``APP_REGISTRY``
+- Basic familiarity with Piccolo migrations
+
+Step-by-Step Instructions
+-------------------------
+
+1. Remove the Table from ``app_a``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In ``app_a/tables.py``, delete or comment out the ``TableA`` class definition.
+
+2. Create a Migration in ``app_a``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Run the following command in your terminal:
+
+.. code-block:: bash
+
+    piccolo migrations new app_a --auto
+
+This will create a migration that removes the table from ``app_a``.
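+
+The generated file will look roughly like this (a sketch; the real file has a
+generated ``ID``, and the exact contents vary between Piccolo versions):
+
+.. code-block:: python
+
+    from piccolo.apps.migrations.auto.migration_manager import MigrationManager
+
+    ID = "2024-01-01T00:00:00"
+    DESCRIPTION = "Dropping TableA"
+
+    async def forwards():
+        manager = MigrationManager(migration_id=ID, app_name="app_a")
+        manager.drop_table(class_name="TableA", tablename="table_a")
+        return manager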
+
+3. Fake Apply the Migration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To prevent the table from actually being dropped from the database, apply the migration using the ``--fake`` flag:
+
+.. code-block:: bash
+
+    piccolo migrations forwards app_a --fake
+
+This marks the migration as applied without making real changes to the database.
+
+4. Move the Table to ``app_b``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Copy the ``TableA`` class definition into ``app_b/tables.py``.
+
+Ensure the definition matches exactly what it was in ``app_a``.
+
+5. Create a Migration in ``app_b``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Generate a new migration for ``app_b`` to register ``TableA``:
+
+.. code-block:: bash
+
+    piccolo migrations new app_b --auto
+
+6. Apply the Migration in ``app_b``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Fake apply the new migration:
+
+.. code-block:: bash
+
+    piccolo migrations forwards app_b --fake
+
+Because the table already exists in the database, Piccolo will associate it with ``app_b`` without duplicating or altering it.
+
+Notes & Tips
+------------
+
+- This process preserves your data because it avoids actually dropping or creating the table.
+- Always back up your database before doing schema changes.
+- Inspect the migration files to understand what Piccolo is tracking.
diff --git a/docs/src/piccolo/tutorials/using_sqlite_and_asyncio_effectively.rst b/docs/src/piccolo/tutorials/using_sqlite_and_asyncio_effectively.rst
new file mode 100644
index 000000000..67014dd5e
--- /dev/null
+++ b/docs/src/piccolo/tutorials/using_sqlite_and_asyncio_effectively.rst
@@ -0,0 +1,103 @@
+.. _UsingSQLitAndAsyncioEffectively:
+
+Using SQLite and asyncio effectively
+====================================
+
+When using Piccolo with SQLite, there are some best practices to follow.
+
+asyncio => lots of connections
+------------------------------
+
+With asyncio, we can potentially open lots of database connections, and attempt
+to perform concurrent database writes.
+
+SQLite doesn't support such concurrent behavior as effectively as Postgres, so
+we need to be careful.
+
+One write at a time
+~~~~~~~~~~~~~~~~~~~
+
+SQLite can easily support lots of transactions concurrently if they are reading,
+but only one write can be performed at a time.
+
+-------------------------------------------------------------------------------
+
+.. _SQLiteTransactionTypes:
+
+Transactions
+------------
+
+SQLite has several transaction types, as specified by Piccolo's
+``TransactionType`` enum:
+
+.. currentmodule:: piccolo.engine.sqlite
+
+.. autoclass:: TransactionType
+    :members:
+    :undoc-members:
+
+Which to use?
+~~~~~~~~~~~~~
+
+When creating a transaction, Piccolo uses ``DEFERRED`` by default (to be
+consistent with SQLite).
+
+This means that the first SQL query executed within the transaction determines
+whether it's a **READ** or **WRITE**:
+
+* **READ** - if the first query is a ``SELECT``
+* **WRITE** - if the first query is something like an ``INSERT`` / ``UPDATE`` / ``DELETE``
+
+If a transaction starts off with a ``SELECT``, but then tries to perform an ``INSERT`` / ``UPDATE`` / ``DELETE``,
+SQLite tries to 'promote' the transaction so it can write.
+
+The problem is, if multiple concurrent connections try doing this at the same time,
+SQLite will return a database locked error.
+
+So if you're creating a transaction which you know will perform writes, then
+create an ``IMMEDIATE`` transaction:
+
+.. code-block:: python
+
+    from piccolo.engine.sqlite import TransactionType
+
+    async with Band._meta.db.transaction(
+        transaction_type=TransactionType.immediate
+    ):
+        # We perform a SELECT first, but as it's an IMMEDIATE transaction,
+        # we can later perform writes without getting a database locked
+        # error.
+        if not await Band.exists().where(Band.name == 'Pythonistas'):
+            await Band.objects().create(name="Pythonistas")
+
+Multiple ``IMMEDIATE`` transactions can exist concurrently - SQLite uses a lock
+to make sure only one of them writes at a time.
+
+If your transaction will just be performing ``SELECT`` queries, then just use
+the default ``DEFERRED`` transactions - you will get improved performance, as
+no locking is involved:
+
+.. code-block:: python
+
+    async with Band._meta.db.transaction():
+        bands = await Band.select()
+        managers = await Manager.select()
+
+-------------------------------------------------------------------------------
+
+timeout
+-------
+
+It's recommended to specify the ``timeout`` argument in :class:`SQLiteEngine `.
+
+.. code-block:: python
+
+    DB = SQLiteEngine(timeout=60)
+
+Imagine you have a web app, and each endpoint creates a transaction which runs
+multiple queries. With SQLite, only a single write operation can happen at a
+time, so if several connections are open, they may be queued for a while.
+
+Increasing ``timeout`` means that queries are less likely to time out.
+
+To find out more about ``timeout`` see the Python :func:`sqlite3 docs `.
diff --git a/docs/src/piccolo/v1/index.rst b/docs/src/piccolo/v1/index.rst
new file mode 100644
index 000000000..0c1a11fa6
--- /dev/null
+++ b/docs/src/piccolo/v1/index.rst
@@ -0,0 +1,48 @@
+.. _PiccoloV1:
+
+
+About Piccolo v1
+================
+
+**20th October**
+
+Piccolo v1 is now available!
+
+We migrated to Pydantic v2, and also migrated Piccolo Admin to Vue 3, which
+puts the project in a good place moving forward.
+
+We don't anticipate any major issues for people who are upgrading. If you
+encounter any bugs let us know.
+
+Make sure you have v1 of Piccolo, Piccolo API, and Piccolo Admin.
+
+**2nd August 2023**
+
+Piccolo started in August 2018, and as of this writing is close to 5 years old.
+
+During that time we've had very few, if any, breaking changes. Stability has
+always been very important, as we rely on it for our production apps.
+
+So why release v1 now? We probably should have released v1 several years ago,
+but such are things. We now have some unavoidable breaking changes due to one
+of our main dependencies (Pydantic) releasing v2.
+
+In v2, the core of Pydantic has been rewritten in Rust, and has impressive
+improvements in performance. Likewise, other libraries in the ecosystem (such
+as FastAPI) have moved to Pydantic v2. It only makes sense that Piccolo does it
+too.
+
+In terms of your own code, you shouldn't see much difference. We removed the
+``pydantic_config_class`` from ``create_pydantic_model``, and replaced it with
+``pydantic_config``, but that's about it.
+
+However, quite a bit of internal code in Piccolo and its sister libraries
+`Piccolo API `_ and
+`Piccolo Admin `_ had to be changed to
+support Pydantic v2. Supporting both Pydantic v1 and Pydantic v2 would be quite
+burdensome.
+
+So Piccolo v1 will just use Pydantic v2 and above.
+
+If you can't upgrade to Pydantic v2, then pin your Piccolo version to ``0.118.0``.
+You can find the `docs here for 0.118.0 `_.
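+
+For reference, the ``pydantic_config_class`` to ``pydantic_config`` change
+mentioned above looks roughly like this (a sketch; ``MyTable`` is a
+placeholder):
+
+.. code-block:: python
+
+    from pydantic import ConfigDict
+
+    from piccolo.utils.pydantic import create_pydantic_model
+
+    # Before (Piccolo < 1.0, Pydantic v1):
+    #
+    # class MyConfig(pydantic.BaseConfig):
+    #     extra = "forbid"
+    #
+    # model = create_pydantic_model(MyTable, pydantic_config_class=MyConfig)
+
+    # After (Piccolo v1, Pydantic v2):
+    model = create_pydantic_model(
+        MyTable,
+        pydantic_config=ConfigDict(extra="forbid"),
+    )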
diff --git a/piccolo/__init__.py b/piccolo/__init__.py
index b6ea1bd86..00a17cafa 100644
--- a/piccolo/__init__.py
+++ b/piccolo/__init__.py
@@ -1 +1 @@
-__VERSION__ = "0.33.0"
+__VERSION__ = "1.30.0"
diff --git a/piccolo/apps/app/commands/new.py b/piccolo/apps/app/commands/new.py
index 775da5240..c8beee26a 100644
--- a/piccolo/apps/app/commands/new.py
+++ b/piccolo/apps/app/commands/new.py
@@ -2,12 +2,16 @@
 
 import importlib
 import os
+import pathlib
+import string
 import sys
-import typing as t
+from typing import Any
 
-import black  # type: ignore
+import black
 import jinja2
 
+from piccolo.conf.apps import PiccoloConfUpdater
+
 TEMPLATE_DIRECTORY = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "templates"
 )
@@ -30,13 +34,36 @@ def module_exists(module_name: str) -> bool:
     return True
 
 
-def new_app(app_name: str, root: str = "."):
-    print(f"Creating {app_name} app ...")
+APP_NAME_ALLOWED_CHARACTERS = [*string.ascii_lowercase, *string.digits, "_"]
 
-    app_root = os.path.join(root, app_name)
 
-    if os.path.exists(app_root):
-        sys.exit("Folder already exists - exiting.")
+def validate_app_name(app_name: str):
+    """
+    Make sure the app name is something which is a valid Python package name.
+
+    :raises ValueError:
+        If ``app_name`` isn't valid.
+
+    """
+    for char in app_name:
+        if char.lower() not in APP_NAME_ALLOWED_CHARACTERS:
+            raise ValueError(
+                f"The app name contains a disallowed character: `{char}`. "
+                "It must only include a-z, 0-9, and _ characters."
+            )
+
+
+def get_app_module(app_name: str, root: str) -> str:
+    return ".".join([*pathlib.Path(root).parts, app_name, "piccolo_app"])
+
+
+def new_app(app_name: str, root: str = ".", register: bool = False):
+    print(f"Creating {app_name} app ...")
+
+    try:
+        validate_app_name(app_name=app_name)
+    except ValueError as exception:
+        sys.exit(str(exception))
 
     if module_exists(app_name):
         sys.exit(
@@ -44,19 +71,24 @@ def new_app(app_name: str, root: str = "."):
             "Python module. Please choose a different name for your app."
         )
 
-    os.mkdir(app_root)
+    app_root = os.path.join(root, app_name)
+
+    if os.path.exists(app_root):
+        sys.exit("Folder already exists - exiting.")
+
+    os.makedirs(app_root)
 
     with open(os.path.join(app_root, "__init__.py"), "w"):
         pass
 
-    templates: t.Dict[str, t.Any] = {
+    templates: dict[str, Any] = {
        "piccolo_app.py": {"app_name": app_name},
        "tables.py": {},
    }
 
     for filename, context in templates.items():
         with open(os.path.join(app_root, filename), "w") as f:
-            template = JINJA_ENV.get_template(filename + ".jinja")
+            template = JINJA_ENV.get_template(f"{filename}.jinja")
             file_contents = template.render(**context)
             file_contents = black.format_str(
                 file_contents, mode=black.FileMode(line_length=80)
@@ -69,16 +101,22 @@
 
     with open(os.path.join(migrations_folder_path, "__init__.py"), "w"):
         pass
 
+    if register:
+        app_module = get_app_module(app_name=app_name, root=root)
+        PiccoloConfUpdater().register_app(app_module=app_module)
+
 
-def new(app_name: str, root: str = "."):
+def new(app_name: str, root: str = ".", register: bool = False):
     """
     Creates a new Piccolo app.
 
     :param app_name:
         The name of the new app.
     :param root:
-        Where to create the app e.g. /my/folder. By default it creates the
+        Where to create the app e.g. ./my/folder. By default it creates the
         app in the current directory.
+    :param register:
+        If True, the app is registered automatically in piccolo_conf.py.
""" - new_app(app_name=app_name, root=root) + new_app(app_name=app_name, root=root, register=register) diff --git a/piccolo/apps/app/commands/templates/piccolo_app.py.jinja b/piccolo/apps/app/commands/templates/piccolo_app.py.jinja index 9c86cecd7..39d629814 100644 --- a/piccolo/apps/app/commands/templates/piccolo_app.py.jinja +++ b/piccolo/apps/app/commands/templates/piccolo_app.py.jinja @@ -5,7 +5,7 @@ the APP_CONFIG. import os -from piccolo.conf.apps import AppConfig +from piccolo.conf.apps import AppConfig, table_finder, get_package CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) @@ -13,8 +13,15 @@ CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) APP_CONFIG = AppConfig( app_name='{{ app_name }}', - migrations_folder_path=os.path.join(CURRENT_DIRECTORY, 'piccolo_migrations'), - table_classes=[], + migrations_folder_path=os.path.join( + CURRENT_DIRECTORY, + 'piccolo_migrations' + ), + table_classes=table_finder( + modules=[".tables"], + package=get_package(__name__), + exclude_imported=True + ), migration_dependencies=[], commands=[] ) diff --git a/piccolo/apps/asgi/commands/new.py b/piccolo/apps/asgi/commands/new.py index 9e70346ef..f85b7b980 100644 --- a/piccolo/apps/asgi/commands/new.py +++ b/piccolo/apps/asgi/commands/new.py @@ -2,25 +2,33 @@ import os import shutil -import typing as t -import black # type: ignore -import colorama # type: ignore +import black +import colorama from jinja2 import Environment, FileSystemLoader TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates/app/") -SERVERS = ["uvicorn", "Hypercorn"] -ROUTERS = ["starlette", "fastapi", "blacksheep"] +SERVERS = ["uvicorn", "Hypercorn", "granian"] +ROUTER_DEPENDENCIES = { + "starlette": ["starlette"], + "fastapi": ["fastapi"], + "blacksheep": ["blacksheep[full]"], + "litestar": ["litestar"], + "ravyn": ["ravyn"], + "lilya": ["lilya"], + "quart": ["quart", "quart_schema"], + "falcon": ["falcon"], + "sanic": ["sanic", "sanic_ext"], +} +ROUTERS = list(ROUTER_DEPENDENCIES.keys()) def print_instruction(message: str): print(f"{colorama.Fore.CYAN}{message}{colorama.Fore.RESET}") -def get_options_string(options: t.List[str]): - return ", ".join( - [f"{name} [{index}]" for index, name in enumerate(options)] - ) +def get_options_string(options: list[str]): + return ", ".join(f"{name} [{index}]" for index, name in enumerate(options)) def get_routing_framework() -> str: @@ -49,8 +57,11 @@ def new(root: str = ".", name: str = "piccolo_project"): """ tree = os.walk(TEMPLATE_DIR) + router = get_routing_framework() + template_context = { - "router": get_routing_framework(), + "router": router, + "router_dependencies": ROUTER_DEPENDENCIES.get(router) or [router], "server": get_server(), "project_identifier": name.replace(" ", "_").lower(), } @@ -75,7 +86,7 @@ def new(root: str = ".", name: str = "piccolo_project"): os.mkdir(sub_dir_path) for file_name in file_names: - if file_name.startswith("_"): + if file_name.startswith("_") and file_name != "__init__.py.jinja": continue extension = file_name.rsplit(".")[0] @@ -98,7 +109,7 @@ def new(root: str = ".", name: str = "piccolo_project"): ) except Exception as exception: print(f"Problem processing {output_file_name}") - raise exception + raise exception from exception with open( os.path.join(output_dir_path, output_file_name), "w" diff --git a/piccolo/apps/asgi/commands/templates/app/README.md.jinja b/piccolo/apps/asgi/commands/templates/app/README.md.jinja new file mode 100644 index 000000000..b8c4e06ff --- /dev/null +++ 
b/piccolo/apps/asgi/commands/templates/app/README.md.jinja
@@ -0,0 +1,21 @@
+# {{ project_identifier }}
+
+## Setup
+
+### Install requirements
+
+```bash
+pip install -r requirements.txt
+```
+
+### Getting started guide
+
+```bash
+python main.py
+```
+
+### Running tests
+
+```bash
+piccolo tester run
+```
diff --git a/piccolo/apps/asgi/commands/templates/app/_blacksheep_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_blacksheep_app.py.jinja
index 2d668f595..7873fa47c 100644
--- a/piccolo/apps/asgi/commands/templates/app/_blacksheep_app.py.jinja
+++ b/piccolo/apps/asgi/commands/templates/app/_blacksheep_app.py.jinja
@@ -1,20 +1,17 @@
-import typing as t
-
-from piccolo_admin.endpoints import create_admin
-from piccolo_api.crud.serializers import create_pydantic_model
-from piccolo.engine import engine_finder
-
+from typing import Any
+from blacksheep.exceptions import HTTPException
 from blacksheep.server import Application
 from blacksheep.server.bindings import FromJSON
-from blacksheep.server.responses import json
-from blacksheep.server.openapi.v3 import OpenAPIHandler
-from openapidocs.v3 import Info
+from blacksheep.server.openapi.v3 import OpenAPIHandler, Info
+from piccolo.engine import engine_finder
+from piccolo_admin.endpoints import create_admin
+from piccolo_api.crud.serializers import create_pydantic_model
 
 from home.endpoints import home
-from home.piccolo_app import APP_CONFIG
+from home.piccolo_app import APP_CONFIG
 from home.tables import Task
 
-
 app = Application()
 
 app.mount(
@@ -36,69 +33,96 @@ app.serve_files("static", root_path="/static")
 
 app.router.add_get("/", home)
 
-TaskModelIn = create_pydantic_model(table=Task, model_name="TaskModelIn")
-TaskModelOut = create_pydantic_model(
-    table=Task, include_default_columns=True, model_name="TaskModelOut"
+TaskModelIn: Any = create_pydantic_model(
+    table=Task,
+    model_name="TaskModelIn",
 )
-TaskModelPartial = create_pydantic_model(
-    table=Task, model_name="TaskModelPartial", all_optional=True
+TaskModelOut: Any = create_pydantic_model(
+    table=Task,
+    include_default_columns=True,
+    model_name="TaskModelOut",
+)
+TaskModelPartial: Any = create_pydantic_model(
+    table=Task,
+    model_name="TaskModelPartial",
+    all_optional=True,
+)
+
+
+# Check if the record is None. 
Use for query callback +def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]: + if result is None: + raise HTTPException(status=404) + return result + + @app.router.get("/tasks/") -async def tasks() -> t.List[TaskModelOut]: - return await Task.select().order_by(Task.id).run() +async def tasks() -> list[TaskModelOut]: + tasks = await Task.select().order_by(Task._meta.primary_key, ascending=False) + return [TaskModelOut(**task) for task in tasks] + + +@app.router.get("/tasks/{task_id}/") +async def single_task(task_id: int) -> TaskModelOut: + task = ( + await Task.select() + .where(Task._meta.primary_key == task_id) + .first() + .callback(check_record_not_found) + ) + return TaskModelOut(**task) @app.router.post("/tasks/") -async def create_task(task: FromJSON[TaskModelIn]) -> TaskModelOut: - task = Task(**task.value.__dict__) - await task.save().run() - return TaskModelOut(**task.__dict__) +async def create_task(task_model: FromJSON[TaskModelIn]) -> TaskModelOut: + task = Task(**task_model.value.model_dump()) + await task.save() + return TaskModelOut(**task.to_dict()) @app.router.put("/tasks/{task_id}/") -async def put_task( - task_id: int, task: FromJSON[TaskModelIn] -) -> TaskModelOut: - _task = await Task.objects().where(Task.id == task_id).first().run() - if not _task: - return json({}, status=404) +async def put_task(task_id: int, task_model: FromJSON[TaskModelIn]) -> TaskModelOut: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) - for key, value in task.value.__dict__.items(): - setattr(_task, key, value) + for key, value in task_model.value.model_dump().items(): + setattr(task, key, value) - await _task.save().run() - - return TaskModelOut(**_task.__dict__) + await task.save() + return TaskModelOut(**task.to_dict()) @app.router.patch("/tasks/{task_id}/") async def patch_task( - task_id: int, task: FromJSON[TaskModelPartial] + task_id: int, task_model: FromJSON[TaskModelPartial] ) -> TaskModelOut: - _task = await Task.objects().where(Task.id == task_id).first().run() - if not _task: - return json({}, status=404) - - for key, value in task.value.__dict__.items(): - if value is not None: - setattr(_task, key, value) + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) - await _task.save().run() + for key, value in task_model.value.model_dump().items(): + if value is not None: + setattr(task, key, value) - return TaskModelOut(**_task.__dict__) + await task.save() + return TaskModelOut(**task.to_dict()) @app.router.delete("/tasks/{task_id}/") -async def delete_task(task_id: int): - task = await Task.objects().where(Task.id == task_id).first().run() - if not task: - return json({}, status=404) - - await task.remove().run() - - return json({}) +async def delete_task(task_id: int) -> None: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + await task.remove() async def open_database_connection_pool(application): @@ -119,3 +143,5 @@ async def close_database_connection_pool(application): app.on_start += open_database_connection_pool app.on_stop += close_database_connection_pool + +app.router.apply_routes() diff --git a/piccolo/apps/asgi/commands/templates/app/_falcon_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_falcon_app.py.jinja new file mode 100644 index 000000000..1c5918999 --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/_falcon_app.py.jinja @@ -0,0 +1,115 
@@ +import os +from typing import Any + +import falcon.asgi +from hypercorn.middleware import DispatcherMiddleware +from piccolo.engine import engine_finder +from piccolo_admin.endpoints import create_admin + +from home.endpoints import HomeEndpoint +from home.piccolo_app import APP_CONFIG +from home.tables import Task + + +async def open_database_connection_pool(): + try: + engine = engine_finder() + await engine.start_connection_pool() + except Exception: + print("Unable to connect to the database") + + +async def close_database_connection_pool(): + try: + engine = engine_finder() + await engine.close_connection_pool() + except Exception: + print("Unable to connect to the database") + + +# Check if the record is None. Use for query callback +def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]: + if result is None: + raise falcon.HTTPNotFound() + return result + + +class LifespanMiddleware: + async def process_startup( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: + await open_database_connection_pool() + + async def process_shutdown( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: + await close_database_connection_pool() + + +class TaskCollectionResource: + async def on_get(self, req, resp): + tasks = await Task.select().order_by(Task._meta.primary_key, ascending=False) + resp.media = tasks + + async def on_post(self, req, resp): + data = await req.media + task = Task(**data) + await task.save() + resp.status = falcon.HTTP_201 + resp.media = task.to_dict() + + +class TaskItemResource: + async def on_get(self, req, resp, task_id): + task = ( + await Task.select() + .where(Task._meta.primary_key == task_id) + .first() + .callback(check_record_not_found) + ) + resp.status = falcon.HTTP_200 + resp.media = task + + async def on_put(self, req, resp, task_id): + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + + data = await req.media + for key, value in data.items(): + setattr(task, key, value) + + await task.save() + resp.status = falcon.HTTP_200 + resp.media = task.to_dict() + + async def on_delete(self, req, resp, task_id): + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + resp.status = falcon.HTTP_204 + await task.remove() + + +app: Any = falcon.asgi.App(middleware=LifespanMiddleware()) +app.add_static_route("/static", directory=os.path.abspath("static")) +app.add_route("/", HomeEndpoint()) +app.add_route("/tasks/", TaskCollectionResource()) +app.add_route("/tasks/{task_id:int}", TaskItemResource()) + + +# enable the admin application using DispatcherMiddleware +app = DispatcherMiddleware( # type: ignore + { + "/admin": create_admin( + tables=APP_CONFIG.table_classes, + # Required when running under HTTPS: + # allowed_hosts=['my_site.com'] + ), + "": app, + } +) diff --git a/piccolo/apps/asgi/commands/templates/app/_fastapi_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_fastapi_app.py.jinja index 4053c7cde..7a096528a 100644 --- a/piccolo/apps/asgi/commands/templates/app/_fastapi_app.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/_fastapi_app.py.jinja @@ -1,11 +1,12 @@ -import typing as t +from contextlib import asynccontextmanager +from typing import Any +from fastapi import FastAPI, status +from fastapi.exceptions import HTTPException +from piccolo.engine import engine_finder from piccolo_admin.endpoints import create_admin from piccolo_api.crud.serializers import 
create_pydantic_model -from piccolo.engine import engine_finder -from starlette.routing import Route, Mount -from fastapi import FastAPI -from fastapi.responses import JSONResponse +from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles from home.endpoints import HomeEndpoint @@ -13,6 +14,29 @@ from home.piccolo_app import APP_CONFIG from home.tables import Task +async def open_database_connection_pool(): + try: + engine = engine_finder() + await engine.start_connection_pool() + except Exception: + print("Unable to connect to the database") + + +async def close_database_connection_pool(): + try: + engine = engine_finder() + await engine.close_connection_pool() + except Exception: + print("Unable to connect to the database") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await open_database_connection_pool() + yield + await close_database_connection_pool() + + app = FastAPI( routes=[ Route("/", HomeEndpoint), @@ -22,71 +46,82 @@ app = FastAPI( tables=APP_CONFIG.table_classes, # Required when running under HTTPS: # allowed_hosts=['my_site.com'] - ) + ), ), Mount("/static/", StaticFiles(directory="static")), ], + lifespan=lifespan, ) -TaskModelIn = create_pydantic_model(table=Task, model_name='TaskModelIn') -TaskModelOut = create_pydantic_model( +TaskModelIn: Any = create_pydantic_model( table=Task, - include_default_columns=True, - model_name='TaskModelOut' + model_name="TaskModelIn", ) - -@app.get("/tasks/", response_model=t.List[TaskModelOut]) -async def tasks(): - return await Task.select().order_by(Task.id).run() - - -@app.post('/tasks/', response_model=TaskModelOut) -async def create_task(task: TaskModelIn): - task = Task(**task.__dict__) - await task.save().run() - return TaskModelOut(**task.__dict__) +TaskModelOut: Any = create_pydantic_model( + table=Task, + include_default_columns=True, + model_name="TaskModelOut", +) -@app.put('/tasks/{task_id}/', response_model=TaskModelOut) -async def update_task(task_id: int, task: TaskModelIn): - _task = await Task.objects().where(Task.id == task_id).first().run() - if not _task: - return JSONResponse({}, status_code=404) +# Check if the record is None. 
Use for query callback +def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]: + if result is None: + raise HTTPException( + detail="Record not found", + status_code=status.HTTP_404_NOT_FOUND, + ) + return result - for key, value in task.__dict__.items(): - setattr(_task, key, value) - await _task.save().run() +@app.get("/tasks/", response_model=list[TaskModelOut], tags=["Task"]) +async def tasks(): + return await Task.select().order_by(Task._meta.primary_key, ascending=False) - return TaskModelOut(**_task.__dict__) +@app.get("/tasks/{task_id}/", response_model=TaskModelOut, tags=["Task"]) +async def single_task(task_id: int): + task = ( + await Task.select() + .where(Task._meta.primary_key == task_id) + .first() + .callback(check_record_not_found) + ) + return task -@app.delete('/tasks/{task_id}/') -async def delete_task(task_id: int): - task = await Task.objects().where(Task.id == task_id).first().run() - if not task: - return JSONResponse({}, status_code=404) - await task.remove().run() +@app.post("/tasks/", response_model=TaskModelOut, tags=["Task"]) +async def create_task(task_model: TaskModelIn): + task = Task(**task_model.model_dump()) + await task.save() + return task.to_dict() - return JSONResponse({}) +@app.put("/tasks/{task_id}/", response_model=TaskModelOut, tags=["Task"]) +async def update_task(task_id: int, task_model: TaskModelIn): + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + for key, value in task_model.model_dump().items(): + setattr(task, key, value) -@app.on_event("startup") -async def open_database_connection_pool(): - try: - engine = engine_finder() - await engine.start_connection_pool() - except Exception: - print("Unable to connect to the database") + await task.save() + return task.to_dict() -@app.on_event("shutdown") -async def close_database_connection_pool(): - try: - engine = engine_finder() - await engine.close_connection_pool() - except Exception: - print("Unable to connect to the database") +@app.delete( + "/tasks/{task_id}/", + status_code=status.HTTP_204_NO_CONTENT, + tags=["Task"], +) +async def delete_task(task_id: int): + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + await task.remove() diff --git a/piccolo/apps/asgi/commands/templates/app/_lilya_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_lilya_app.py.jinja new file mode 100644 index 000000000..30ce3639e --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/_lilya_app.py.jinja @@ -0,0 +1,44 @@ +from lilya.apps import Lilya +from lilya.routing import Include, Path +from lilya.staticfiles import StaticFiles +from piccolo.engine import engine_finder +from piccolo_admin.endpoints import create_admin +from piccolo_api.crud.endpoints import PiccoloCRUD + +from home.endpoints import HomeController +from home.piccolo_app import APP_CONFIG +from home.tables import Task + +app = Lilya( + routes=[ + Path("/", HomeController), + Include( + "/admin/", + create_admin( + tables=APP_CONFIG.table_classes, + # Required when running under HTTPS: + # allowed_hosts=['my_site.com'] + ), + ), + Include("/static/", StaticFiles(directory="static")), + Include("/tasks/", PiccoloCRUD(table=Task)), + ], +) + + +@app.on_event("on_startup") +async def open_database_connection_pool(): + try: + engine = engine_finder() + await engine.start_connection_pool() + except Exception: + print("Unable to connect to the database") + + +@app.on_event("on_shutdown") 
+async def close_database_connection_pool(): + try: + engine = engine_finder() + await engine.close_connection_pool() + except Exception: + print("Unable to connect to the database") diff --git a/piccolo/apps/asgi/commands/templates/app/_litestar_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_litestar_app.py.jinja new file mode 100644 index 000000000..963ea3c42 --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/_litestar_app.py.jinja @@ -0,0 +1,143 @@ +from typing import Any + +from litestar import Litestar, asgi, delete, get, patch, post +from litestar.contrib.jinja import JinjaTemplateEngine +from litestar.exceptions import NotFoundException +from litestar.static_files import StaticFilesConfig +from litestar.template import TemplateConfig +from litestar.types import Receive, Scope, Send +from piccolo.engine import engine_finder +from piccolo_admin.endpoints import create_admin +from pydantic import BaseModel + +from home.endpoints import home +from home.piccolo_app import APP_CONFIG +from home.tables import Task + +""" +NOTE: `create_pydantic_model` is not compatible with Litestar +version higher than 2.11.0. If you are using Litestar<=2.11.0, +you can use `create_pydantic_model` as in other asgi templates + +from piccolo.utils.pydantic import create_pydantic_model + +TaskModelIn: Any = create_pydantic_model( + table=Task, + model_name="TaskModelIn", +) +TaskModelOut: Any = create_pydantic_model( + table=Task, + include_default_columns=True, + model_name="TaskModelOut", +) +""" + + +class TaskModelIn(BaseModel): + name: str + completed: bool = False + + +class TaskModelOut(BaseModel): + id: int + name: str + completed: bool = False + + +# mounting Piccolo Admin +@asgi("/admin/", is_mount=True) +async def admin(scope: "Scope", receive: "Receive", send: "Send") -> None: + await create_admin(tables=APP_CONFIG.table_classes)(scope, receive, send) + + +# Check if the record is None. 
Use for query callback +def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]: + if result is None: + raise NotFoundException(detail="Record not found") + return result + + +@get("/tasks", tags=["Task"]) +async def tasks() -> list[TaskModelOut]: + tasks = await Task.select().order_by(Task._meta.primary_key, ascending=False) + return [TaskModelOut(**task) for task in tasks] + + +@get("/tasks/{task_id:int}", tags=["Task"]) +async def single_task(task_id: int) -> TaskModelOut: + task = ( + await Task.select() + .where(Task._meta.primary_key == task_id) + .first() + .callback(check_record_not_found) + ) + return TaskModelOut(**task) + + +@post("/tasks", tags=["Task"]) +async def create_task(data: TaskModelIn) -> TaskModelOut: + task = Task(**data.model_dump()) + await task.save() + return TaskModelOut(**task.to_dict()) + + +@patch("/tasks/{task_id:int}", tags=["Task"]) +async def update_task(task_id: int, data: TaskModelIn) -> TaskModelOut: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + + for key, value in data.model_dump().items(): + setattr(task, key, value) + + await task.save() + return TaskModelOut(**task.to_dict()) + + +@delete("/tasks/{task_id:int}", tags=["Task"]) +async def delete_task(task_id: int) -> None: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + await task.remove() + + +async def open_database_connection_pool(): + try: + engine = engine_finder() + await engine.start_connection_pool() + except Exception: + print("Unable to connect to the database") + + +async def close_database_connection_pool(): + try: + engine = engine_finder() + await engine.close_connection_pool() + except Exception: + print("Unable to connect to the database") + + +app = Litestar( + route_handlers=[ + admin, + home, + tasks, + single_task, + create_task, + update_task, + delete_task, + ], + template_config=TemplateConfig( + directory="home/templates", engine=JinjaTemplateEngine + ), + static_files_config=[ + StaticFilesConfig(directories=["static"], path="/static/"), + ], + on_startup=[open_database_connection_pool], + on_shutdown=[close_database_connection_pool], +) diff --git a/piccolo/apps/asgi/commands/templates/app/_quart_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_quart_app.py.jinja new file mode 100644 index 000000000..cf648469a --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/_quart_app.py.jinja @@ -0,0 +1,142 @@ +from typing import Any +from http import HTTPStatus + +from hypercorn.middleware import DispatcherMiddleware +from piccolo.engine import engine_finder +from piccolo_admin.endpoints import create_admin +from piccolo_api.crud.serializers import create_pydantic_model +from quart import Quart +from quart.helpers import abort +from quart_schema import ( + Info, + QuartSchema, + hide, + tag, + validate_request, + validate_response, +) + +from home.endpoints import index +from home.piccolo_app import APP_CONFIG +from home.tables import Task + + +app = Quart(__name__, static_folder="static") +QuartSchema(app, info=Info(title="Quart API", version="0.1.0")) + + +TaskModelIn: Any = create_pydantic_model( + table=Task, + model_name="TaskModelIn", +) +TaskModelOut: Any = create_pydantic_model( + table=Task, + include_default_columns=True, + model_name="TaskModelOut", +) + + +# Check if the record is None. 
Use for query callback
+def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]:
+    if result is None:
+        abort(code=HTTPStatus.NOT_FOUND)
+    return result
+
+
+@app.get("/")
+@hide
+def home():
+    return index()
+
+
+@app.get("/tasks/")
+@validate_response(list[TaskModelOut])
+@tag(["Task"])
+async def tasks():
+    tasks = await Task.select().order_by(Task._meta.primary_key, ascending=False)
+    return [TaskModelOut(**task) for task in tasks], HTTPStatus.OK
+
+
+@app.get("/tasks/<int:task_id>/")
+@validate_response(TaskModelOut)
+@tag(["Task"])
+async def single_task(task_id: int):
+    task = (
+        await Task.select()
+        .where(Task._meta.primary_key == task_id)
+        .first()
+        .callback(check_record_not_found)
+    )
+    return TaskModelOut(**task), HTTPStatus.OK
+
+
+@app.post("/tasks/")
+@validate_request(TaskModelIn)
+@validate_response(TaskModelOut)
+@tag(["Task"])
+async def create_task(data: TaskModelIn):
+    task = Task(**data.model_dump())
+    await task.save()
+    return task.to_dict(), HTTPStatus.CREATED
+
+
+@app.put("/tasks/<int:task_id>/")
+@validate_request(TaskModelIn)
+@validate_response(TaskModelOut)
+@tag(["Task"])
+async def update_task(task_id: int, data: TaskModelIn):
+    task = (
+        await Task.objects()
+        .get(Task._meta.primary_key == task_id)
+        .callback(check_record_not_found)
+    )
+
+    for key, value in data.model_dump().items():
+        setattr(task, key, value)
+
+    await task.save()
+
+    return TaskModelOut(**task.to_dict()), HTTPStatus.OK
+
+
+@app.delete("/tasks/<int:task_id>/")
+@tag(["Task"])
+async def delete_task(task_id: int):
+    task = (
+        await Task.objects()
+        .get(Task._meta.primary_key == task_id)
+        .callback(check_record_not_found)
+    )
+    await task.remove()
+    return {}, HTTPStatus.OK
+
+
+@app.before_serving
+async def open_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.start_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+@app.after_serving
+async def close_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.close_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+# enable the admin application using DispatcherMiddleware
+app = DispatcherMiddleware(  # type: ignore
+    {
+        "/admin": create_admin(
+            tables=APP_CONFIG.table_classes,
+            # Required when running under HTTPS:
+            # allowed_hosts=['my_site.com']
+        ),
+        "": app,
+    }
+)
diff --git a/piccolo/apps/asgi/commands/templates/app/_ravyn_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_ravyn_app.py.jinja
new file mode 100644
index 000000000..ad9cc3407
--- /dev/null
+++ b/piccolo/apps/asgi/commands/templates/app/_ravyn_app.py.jinja
@@ -0,0 +1,126 @@
+from typing import Any
+from pathlib import Path
+
+from ravyn import (
+    APIView,
+    Ravyn,
+    Gateway,
+    HTTPException,
+    Include,
+    delete,
+    get,
+    post,
+    put,
+)
+from ravyn.core.config import StaticFilesConfig
+from piccolo.engine import engine_finder
+from piccolo.utils.pydantic import create_pydantic_model
+from piccolo_admin.endpoints import create_admin
+
+from home.endpoints import home
+from home.piccolo_app import APP_CONFIG
+from home.tables import Task
+
+
+async def open_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.start_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+async def close_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.close_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+TaskModelIn: Any = create_pydantic_model(
table=Task, + model_name="TaskModelIn", +) +TaskModelOut: Any = create_pydantic_model( + table=Task, + include_default_columns=True, + model_name="TaskModelOut", +) + + +# Check if the record is None. Use for query callback +def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]: + if result is None: + raise HTTPException( + detail="Record not found", + status_code=404, + ) + return result + + +class TaskAPIView(APIView): + path: str = "/" + tags: list[str] = ["Task"] + + @get("/") + async def tasks(self) -> list[TaskModelOut]: + tasks = await Task.select().order_by(Task._meta.primary_key, ascending=False) + return [TaskModelOut(**task) for task in tasks] + + @get("/{task_id}") + async def single_task(self, task_id: int) -> TaskModelOut: + task = ( + await Task.select() + .where(Task._meta.primary_key == task_id) + .first() + .callback(check_record_not_found) + ) + return TaskModelOut(**task) + + @post("/") + async def create_task(self, payload: TaskModelIn) -> TaskModelOut: + task = Task(**payload.model_dump()) + await task.save() + return TaskModelOut(**task.to_dict()) + + @put("/{task_id}") + async def update_task(self, payload: TaskModelIn, task_id: int) -> TaskModelOut: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + for key, value in payload.model_dump().items(): + setattr(task, key, value) + + await task.save() + return TaskModelOut(**task.to_dict()) + + @delete("/{task_id}") + async def delete_task(self, task_id: int) -> None: + task = ( + await Task.objects() + .get(Task._meta.primary_key == task_id) + .callback(check_record_not_found) + ) + await task.remove() + + +app = Ravyn( + routes=[ + Gateway("/", handler=home), + Gateway("/tasks", handler=TaskAPIView), + Include( + "/admin/", + create_admin( + tables=APP_CONFIG.table_classes, + # Required when running under HTTPS: + # allowed_hosts=['my_site.com'] + ), + ), + ], + static_files_config=StaticFilesConfig(path="/static", directory=Path("static")), + on_startup=[open_database_connection_pool], + on_shutdown=[close_database_connection_pool], +) diff --git a/piccolo/apps/asgi/commands/templates/app/_sanic_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_sanic_app.py.jinja new file mode 100644 index 000000000..de5836d4c --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/_sanic_app.py.jinja @@ -0,0 +1,144 @@ +from typing import Any + +from hypercorn.middleware import DispatcherMiddleware +from piccolo.engine import engine_finder +from piccolo_admin.endpoints import create_admin +from piccolo_api.crud.serializers import create_pydantic_model +from sanic import NotFound, Request, Sanic, json +from sanic_ext import openapi + +from home.endpoints import index +from home.piccolo_app import APP_CONFIG +from home.tables import Task + +app = Sanic(__name__) +app.static("/static/", "static") + + +TaskModelIn: Any = create_pydantic_model( + table=Task, + model_name="TaskModelIn", +) +TaskModelOut: Any = create_pydantic_model( + table=Task, + include_default_columns=True, + model_name="TaskModelOut", +) + + +# Check if the record is None. 
Use for query callback
+def check_record_not_found(result: dict[str, Any]) -> dict[str, Any]:
+    if result is None:
+        raise NotFound(message="Record not found")
+    return result
+
+
+@app.get("/")
+@openapi.exclude()
+def home(request: Request):
+    return index()
+
+
+@app.get("/tasks/")
+@openapi.tag("Task")
+@openapi.response(200, {"application/json": TaskModelOut.model_json_schema()})
+async def tasks(request: Request):
+    return json(
+        await Task.select().order_by(Task._meta.primary_key, ascending=False),
+        status=200,
+    )
+
+
+@app.get("/tasks/<task_id:int>/")
+@openapi.tag("Task")
+@openapi.response(200, {"application/json": TaskModelOut.model_json_schema()})
+async def single_task(request: Request, task_id: int):
+    task = (
+        await Task.select()
+        .where(Task._meta.primary_key == task_id)
+        .first()
+        .callback(check_record_not_found)
+    )
+    return json(task, status=200)
+
+
+@app.post("/tasks/")
+@openapi.definition(
+    body={"application/json": TaskModelIn.model_json_schema()},
+    tag="Task",
+)
+@openapi.response(201, {"application/json": TaskModelOut.model_json_schema()})
+async def create_task(request: Request):
+    task = Task(**request.json)
+    await task.save()
+    return json(task.to_dict(), status=201)
+
+
+@app.put("/tasks/<task_id:int>/")
+@openapi.definition(
+    body={"application/json": TaskModelIn.model_json_schema()},
+    tag="Task",
+)
+@openapi.response(200, {"application/json": TaskModelOut.model_json_schema()})
+async def update_task(request: Request, task_id: int):
+    task = (
+        await Task.objects()
+        .get(Task._meta.primary_key == task_id)
+        .callback(check_record_not_found)
+    )
+    for key, value in request.json.items():
+        setattr(task, key, value)
+
+    await task.save()
+    return json(task.to_dict(), status=200)
+
+
+@app.delete("/tasks/<task_id:int>/")
+@openapi.tag("Task")
+async def delete_task(request: Request, task_id: int):
+    task = (
+        await Task.objects()
+        .get(Task._meta.primary_key == task_id)
+        .callback(check_record_not_found)
+    )
+    await task.remove()
+    return json({}, status=200)
+
+
+async def open_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.start_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+async def close_database_connection_pool():
+    try:
+        engine = engine_finder()
+        await engine.close_connection_pool()
+    except Exception:
+        print("Unable to connect to the database")
+
+
+@app.after_server_start
+async def startup(app, loop):
+    await open_database_connection_pool()
+
+
+@app.before_server_stop
+async def shutdown(app, loop):
+    await close_database_connection_pool()
+
+
+# enable the admin application using DispatcherMiddleware
+app = DispatcherMiddleware(  # type: ignore
+    {
+        "/admin": create_admin(
+            tables=APP_CONFIG.table_classes,
+            # Required when running under HTTPS:
+            # allowed_hosts=['my_site.com']
+        ),
+        "": app,
+    }
+)
diff --git a/piccolo/apps/asgi/commands/templates/app/_starlette_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/_starlette_app.py.jinja
index 1c401bf43..de99c8e15 100644
--- a/piccolo/apps/asgi/commands/templates/app/_starlette_app.py.jinja
+++ b/piccolo/apps/asgi/commands/templates/app/_starlette_app.py.jinja
@@ -1,8 +1,10 @@
+from contextlib import asynccontextmanager
+
+from piccolo.engine import engine_finder
 from piccolo_admin.endpoints import create_admin
 from piccolo_api.crud.endpoints import PiccoloCRUD
-from piccolo.engine import engine_finder
-from starlette.routing import Route, Mount
 from starlette.applications import Starlette
+from starlette.routing import Mount, Route
 from
starlette.staticfiles import StaticFiles from home.endpoints import HomeEndpoint @@ -10,24 +12,6 @@ from home.piccolo_app import APP_CONFIG from home.tables import Task -app = Starlette( - routes=[ - Route("/", HomeEndpoint), - Mount( - "/admin/", - create_admin( - tables=APP_CONFIG.table_classes, - # Required when running under HTTPS: - # allowed_hosts=['my_site.com'] - ) - ), - Mount("/static/", StaticFiles(directory="static")), - Mount("/tasks/", PiccoloCRUD(table=Task)) - ], -) - - -@app.on_event("startup") async def open_database_connection_pool(): try: engine = engine_finder() @@ -36,10 +20,34 @@ async def open_database_connection_pool(): print("Unable to connect to the database") -@app.on_event("shutdown") async def close_database_connection_pool(): try: engine = engine_finder() await engine.close_connection_pool() except Exception: print("Unable to connect to the database") + + +@asynccontextmanager +async def lifespan(app: Starlette): + await open_database_connection_pool() + yield + await close_database_connection_pool() + + +app = Starlette( + routes=[ + Route("/", HomeEndpoint), + Mount( + "/admin/", + create_admin( + tables=APP_CONFIG.table_classes, + # Required when running under HTTPS: + # allowed_hosts=['my_site.com'] + ), + ), + Mount("/static/", StaticFiles(directory="static")), + Mount("/tasks/", PiccoloCRUD(table=Task)), + ], + lifespan=lifespan, +) diff --git a/piccolo/apps/asgi/commands/templates/app/app.py.jinja b/piccolo/apps/asgi/commands/templates/app/app.py.jinja index c21b3e029..9e22c8f28 100644 --- a/piccolo/apps/asgi/commands/templates/app/app.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/app.py.jinja @@ -4,4 +4,16 @@ {% include '_starlette_app.py.jinja' %} {% elif router == 'blacksheep' %} {% include '_blacksheep_app.py.jinja' %} +{% elif router == 'litestar' %} + {% include '_litestar_app.py.jinja' %} +{% elif router == 'ravyn' %} + {% include '_ravyn_app.py.jinja' %} +{% elif router == 'lilya' %} + {% include '_lilya_app.py.jinja' %} +{% elif router == 'quart' %} + {% include '_quart_app.py.jinja' %} +{% elif router == 'falcon' %} + {% include '_falcon_app.py.jinja' %} +{% elif router == 'sanic' %} + {% include '_sanic_app.py.jinja' %} {% endif %} diff --git a/piccolo/apps/asgi/commands/templates/app/conftest.py.jinja b/piccolo/apps/asgi/commands/templates/app/conftest.py.jinja new file mode 100644 index 000000000..70e2d5584 --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/conftest.py.jinja @@ -0,0 +1,17 @@ +import os +import sys + +from piccolo.utils.warnings import colored_warning + + +def pytest_configure(*args): + if os.environ.get("PICCOLO_TEST_RUNNER") != "True": + colored_warning( + "\n\n" + "We recommend running Piccolo tests using the " + "`piccolo tester run` command, which wraps Pytest, and makes " + "sure the test database is being used. " + "To stop this warning, modify conftest.py." 
+ "\n\n" + ) + sys.exit(1) diff --git a/docs/src/_static/.gitkeep b/piccolo/apps/asgi/commands/templates/app/home/__init__.py.jinja similarity index 100% rename from docs/src/_static/.gitkeep rename to piccolo/apps/asgi/commands/templates/app/home/__init__.py.jinja diff --git a/piccolo/apps/asgi/commands/templates/app/home/_blacksheep_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_blacksheep_endpoints.py.jinja index c303b7061..daf650cb6 100644 --- a/piccolo/apps/asgi/commands/templates/app/home/_blacksheep_endpoints.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/home/_blacksheep_endpoints.py.jinja @@ -12,7 +12,7 @@ ENVIRONMENT = jinja2.Environment( ) -def home(): +async def home(request): template = ENVIRONMENT.get_template("home.html.jinja") content = template.render(title="Piccolo + ASGI",) return Response( diff --git a/piccolo/apps/asgi/commands/templates/app/home/_falcon_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_falcon_endpoints.py.jinja new file mode 100644 index 000000000..5fcbde41e --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/home/_falcon_endpoints.py.jinja @@ -0,0 +1,19 @@ +import os + +import falcon +import jinja2 + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +class HomeEndpoint: + async def on_get(self, req, resp): + template = ENVIRONMENT.get_template("home.html.jinja") + content = template.render(title="Piccolo + ASGI",) + resp.status = falcon.HTTP_200 + resp.content_type = "text/html" + resp.text = content diff --git a/piccolo/apps/asgi/commands/templates/app/home/_lilya_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_lilya_endpoints.py.jinja new file mode 100644 index 000000000..1b3328fed --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/home/_lilya_endpoints.py.jinja @@ -0,0 +1,21 @@ +import os + +import jinja2 +from lilya.controllers import Controller +from lilya.responses import HTMLResponse + + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +class HomeController(Controller): + async def get(self, request): + template = ENVIRONMENT.get_template("home.html.jinja") + + content = template.render(title="Piccolo + ASGI",) + + return HTMLResponse(content) diff --git a/piccolo/apps/asgi/commands/templates/app/home/_litestar_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_litestar_endpoints.py.jinja new file mode 100644 index 000000000..e5dfc0661 --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/home/_litestar_endpoints.py.jinja @@ -0,0 +1,22 @@ +import os + +import jinja2 +from litestar import MediaType, Request, Response, get + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +@get(path="/", include_in_schema=False, sync_to_thread=False) +def home(request: Request) -> Response: + template = ENVIRONMENT.get_template("home.html.jinja") + content = template.render(title="Piccolo + ASGI") + return Response( + content, + media_type=MediaType.HTML, + status_code=200, + ) + diff --git a/piccolo/apps/asgi/commands/templates/app/home/_quart_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_quart_endpoints.py.jinja new file mode 100644 index 000000000..977ebd211 --- /dev/null +++ 
b/piccolo/apps/asgi/commands/templates/app/home/_quart_endpoints.py.jinja @@ -0,0 +1,18 @@ +import os + +import jinja2 + +from quart import Response + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +def index(): + template = ENVIRONMENT.get_template("home.html.jinja") + content = template.render(title="Piccolo + ASGI") + return Response(content) + \ No newline at end of file diff --git a/piccolo/apps/asgi/commands/templates/app/home/_ravyn_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_ravyn_endpoints.py.jinja new file mode 100644 index 000000000..0a400c3df --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/home/_ravyn_endpoints.py.jinja @@ -0,0 +1,20 @@ +import os + +import jinja2 +from ravyn import Request, Response, get +from ravyn.responses import HTMLResponse + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +@get(path="/", include_in_schema=False) +def home(request: Request) -> HTMLResponse: + template = ENVIRONMENT.get_template("home.html.jinja") + + content = template.render(title="Piccolo + ASGI",) + + return HTMLResponse(content) diff --git a/piccolo/apps/asgi/commands/templates/app/home/_sanic_endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/_sanic_endpoints.py.jinja new file mode 100644 index 000000000..e6a5f3416 --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/home/_sanic_endpoints.py.jinja @@ -0,0 +1,17 @@ +import os + +import jinja2 + +from sanic import HTTPResponse + +ENVIRONMENT = jinja2.Environment( + loader=jinja2.FileSystemLoader( + searchpath=os.path.join(os.path.dirname(__file__), "templates") + ) +) + + +def index(): + template = ENVIRONMENT.get_template("home.html.jinja") + content = template.render(title="Piccolo + ASGI") + return HTTPResponse(content) \ No newline at end of file diff --git a/piccolo/apps/asgi/commands/templates/app/home/endpoints.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/endpoints.py.jinja index 47e908805..21c3903ab 100644 --- a/piccolo/apps/asgi/commands/templates/app/home/endpoints.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/home/endpoints.py.jinja @@ -2,4 +2,16 @@ {% include '_starlette_endpoints.py.jinja' %} {% elif router == 'blacksheep' %} {% include '_blacksheep_endpoints.py.jinja' %} +{% elif router == 'litestar' %} + {% include '_litestar_endpoints.py.jinja' %} +{% elif router == 'ravyn' %} + {% include '_ravyn_endpoints.py.jinja' %} +{% elif router == 'lilya' %} + {% include '_lilya_endpoints.py.jinja' %} +{% elif router == 'quart' %} + {% include '_quart_endpoints.py.jinja' %} +{% elif router == 'falcon' %} + {% include '_falcon_endpoints.py.jinja' %} +{% elif router == 'sanic' %} + {% include '_sanic_endpoints.py.jinja' %} {% endif %} diff --git a/piccolo/apps/asgi/commands/templates/app/home/piccolo_app.py.jinja b/piccolo/apps/asgi/commands/templates/app/home/piccolo_app.py.jinja index 7b9eb4f4a..a2ac77d56 100644 --- a/piccolo/apps/asgi/commands/templates/app/home/piccolo_app.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/home/piccolo_app.py.jinja @@ -16,7 +16,7 @@ APP_CONFIG = AppConfig( migrations_folder_path=os.path.join( CURRENT_DIRECTORY, "piccolo_migrations" ), - table_classes=table_finder(modules=["home.tables"]), + table_classes=table_finder(modules=["home.tables"], exclude_imported=True), migration_dependencies=[], commands=[], ) diff 
--git a/piccolo/apps/asgi/commands/templates/app/home/templates/home.html.jinja_raw b/piccolo/apps/asgi/commands/templates/app/home/templates/home.html.jinja_raw index 35159a5e0..f5e447a9b 100644 --- a/piccolo/apps/asgi/commands/templates/app/home/templates/home.html.jinja_raw +++ b/piccolo/apps/asgi/commands/templates/app/home/templates/home.html.jinja_raw @@ -8,7 +8,7 @@

[Hunk content garbled in extraction: the HTML markup of home.html.jinja_raw was stripped, leaving bare text. Recoverable changes: the markup around the piccolo_conf.py reference (in the "Postgres" section) and the tables.py reference (in the "Custom Tables" section) is reworked, and the hunk at @@ -51,6 +51,36 @@ adds Litestar, Ravyn, Lilya, Quart, Falcon, and Sanic links to the framework list after the existing Admin and Swagger API entries.]
    {% endblock content %} diff --git a/piccolo/apps/asgi/commands/templates/app/main.py.jinja b/piccolo/apps/asgi/commands/templates/app/main.py.jinja index 8fb934fbb..f3a0ce3ef 100644 --- a/piccolo/apps/asgi/commands/templates/app/main.py.jinja +++ b/piccolo/apps/asgi/commands/templates/app/main.py.jinja @@ -18,4 +18,8 @@ if __name__ == "__main__": asyncio.run(serve(app, CustomConfig())) serve(app) + {% elif server == 'granian' %} + import granian + + granian.Granian("app:app", interface="asgi").serve() {% endif %} diff --git a/piccolo/apps/asgi/commands/templates/app/piccolo_conf_test.py.jinja b/piccolo/apps/asgi/commands/templates/app/piccolo_conf_test.py.jinja new file mode 100644 index 000000000..52eaf191c --- /dev/null +++ b/piccolo/apps/asgi/commands/templates/app/piccolo_conf_test.py.jinja @@ -0,0 +1,12 @@ +from piccolo_conf import * # noqa + + +DB = PostgresEngine( + config={ + "database": "{{ project_identifier }}_test", + "user": "postgres", + "password": "", + "host": "localhost", + "port": 5432, + } +) diff --git a/piccolo/apps/asgi/commands/templates/app/requirements.txt.jinja b/piccolo/apps/asgi/commands/templates/app/requirements.txt.jinja index 4b4a52171..828796740 100644 --- a/piccolo/apps/asgi/commands/templates/app/requirements.txt.jinja +++ b/piccolo/apps/asgi/commands/templates/app/requirements.txt.jinja @@ -1,5 +1,6 @@ -{{ router }} +{%- for router_dependency in router_dependencies -%} +{{ router_dependency }} +{% endfor -%} {{ server }} -jinja2 -piccolo[postgres] -piccolo_admin +piccolo[postgres]>=1.0.0 +piccolo_admin>=1.0.0 \ No newline at end of file diff --git a/piccolo/apps/asgi/commands/templates/app/static/main.css b/piccolo/apps/asgi/commands/templates/app/static/main.css index b846e1327..7aff46dbc 100644 --- a/piccolo/apps/asgi/commands/templates/app/static/main.css +++ b/piccolo/apps/asgi/commands/templates/app/static/main.css @@ -4,7 +4,8 @@ body, html { } body { - background-color: #f0f0f0; + background-color: #f0f7fd; + color: #2b475f; font-family: 'Open Sans', sans-serif; } @@ -16,6 +17,7 @@ div.hero { a { color: #4C89C8; + text-decoration: none; } div.hero h1 { @@ -36,20 +38,27 @@ div.content { max-width: 50rem; padding: 2rem; transform: translateY(-4rem); + box-shadow: 0px 1px 1px 1px rgb(0,0,0,0.05); } -div.content h2 { +div.content h2, div.content h3 { font-weight: normal; - border-bottom: 4px solid #f0f0f0; +} + +div.content code { + padding: 2px 4px; + background-color: #f0f7fd; + border-radius: 0.2rem; } p.code { - background-color: #2b2b2b; + background-color: #233d58; color: white; font-family: monospace; padding: 1rem; margin: 0; display: block; + border-radius: 0.2rem; } p.code span { diff --git a/tests/example_app/__init__.py b/piccolo/apps/fixtures/__init__.py similarity index 100% rename from tests/example_app/__init__.py rename to piccolo/apps/fixtures/__init__.py diff --git a/piccolo/apps/fixtures/commands/__init__.py b/piccolo/apps/fixtures/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/apps/fixtures/commands/dump.py b/piccolo/apps/fixtures/commands/dump.py new file mode 100644 index 000000000..1114a6c86 --- /dev/null +++ b/piccolo/apps/fixtures/commands/dump.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import Any, Optional + +from piccolo.apps.fixtures.commands.shared import ( + FixtureConfig, + create_pydantic_fixture_model, +) +from piccolo.conf.apps import Finder +from piccolo.table import sort_table_classes + + +async def get_dump( + fixture_configs: 
list[FixtureConfig],
+) -> dict[str, Any]:
+    """
+    Gets the data for each table specified and returns a data structure like:
+
+    .. code-block:: python
+
+        {
+            'my_app_name': {
+                'MyTableName': [
+                    {
+                        'id': 1,
+                        'my_column_name': 'foo'
+                    }
+                ]
+            }
+        }
+
+    """
+    finder = Finder()
+
+    output: dict[str, Any] = {}
+
+    for fixture_config in fixture_configs:
+        app_config = finder.get_app_config(app_name=fixture_config.app_name)
+        table_classes = [
+            i
+            for i in app_config.table_classes
+            if i.__name__ in fixture_config.table_class_names
+        ]
+        sorted_table_classes = sort_table_classes(table_classes)
+
+        output[fixture_config.app_name] = {}
+
+        for table_class in sorted_table_classes:
+            data = await table_class.select().run()
+            output[fixture_config.app_name][table_class.__name__] = data
+
+    return output
+
+
+async def dump_to_json_string(
+    fixture_configs: list[FixtureConfig],
+) -> str:
+    """
+    Dumps all of the data for the given tables into a JSON string.
+    """
+    dump = await get_dump(fixture_configs=fixture_configs)
+    pydantic_model = create_pydantic_fixture_model(
+        fixture_configs=fixture_configs
+    )
+    return pydantic_model(**dump).model_dump_json(indent=4)
+
+
+def parse_args(apps: str, tables: str) -> list[FixtureConfig]:
+    """
+    Works out which apps and tables the user is referring to.
+    """
+    finder = Finder()
+    app_names = []
+
+    if apps == "all":
+        app_names = finder.get_sorted_app_names()
+    elif "," in apps:
+        app_names = apps.split(",")
+    else:
+        # Must be a single app name
+        app_names.append(apps)
+
+    table_class_names: Optional[list[str]] = None
+
+    if tables != "all":
+        table_class_names = tables.split(",") if "," in tables else [tables]
+    output: list[FixtureConfig] = []
+
+    for app_name in app_names:
+        app_config = finder.get_app_config(app_name=app_name)
+        table_classes = app_config.table_classes
+
+        if table_class_names is None:
+            fixture_configs = [i.__name__ for i in table_classes]
+        else:
+            fixture_configs = [
+                i.__name__
+                for i in table_classes
+                if i.__name__ in table_class_names
+            ]
+        output.append(
+            FixtureConfig(
+                app_name=app_name,
+                table_class_names=fixture_configs,
+            )
+        )
+
+    return output
+
+
+async def dump(apps: str = "all", tables: str = "all"):
+    """
+    Serialises the data from the given Piccolo apps / tables, and prints it
+    out.
+
+    :param apps:
+        For all apps, specify `all`. For specific apps, pass in a comma
+        separated list e.g. `blog,profiles,billing`. For a single app, just
+        pass in the name of that app, e.g. `blog`.
+    :param tables:
+        For all tables, specify `all`. For specific tables, pass in a comma
+        separated list e.g. `Post,Tag`. For a single table, just
+        pass in the name of that table, e.g. `Post`.
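+
+    For example (hypothetical invocations)::
+
+        piccolo fixtures dump
+        piccolo fixtures dump --apps=blog --tables=Post,Tag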
+ + """ + fixture_configs = parse_args(apps=apps, tables=tables) + json_string = await dump_to_json_string(fixture_configs=fixture_configs) + print(json_string) diff --git a/piccolo/apps/fixtures/commands/load.py b/piccolo/apps/fixtures/commands/load.py new file mode 100644 index 000000000..835f814a9 --- /dev/null +++ b/piccolo/apps/fixtures/commands/load.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import sys +from typing import Optional + +import typing_extensions + +from piccolo.apps.fixtures.commands.shared import ( + FixtureConfig, + create_pydantic_fixture_model, +) +from piccolo.conf.apps import Finder +from piccolo.engine import engine_finder +from piccolo.query.mixins import OnConflictAction +from piccolo.table import Table, sort_table_classes +from piccolo.utils.encoding import load_json +from piccolo.utils.list import batch + + +async def load_json_string( + json_string: str, + chunk_size: int = 1000, + on_conflict_action: Optional[OnConflictAction] = None, +): + """ + Parses the JSON string, and inserts the parsed data into the database. + """ + # We have to deserialise the JSON to find out which apps and tables it + # contains, so we can create a Pydantic model. + # Then we let Pydantic do the proper deserialisation, as it does a much + # better job of deserialising dates, datetimes, bytes etc. + deserialised_contents = load_json(json_string) + + app_names = deserialised_contents.keys() + + fixture_configs = [ + FixtureConfig( + app_name=app_name, + table_class_names=list(deserialised_contents[app_name].keys()), + ) + for app_name in app_names + ] + pydantic_model_class = create_pydantic_fixture_model( + fixture_configs=fixture_configs + ) + + fixture_pydantic_model = pydantic_model_class.model_validate_json( + json_string + ) + + finder = Finder() + engine = engine_finder() + + if engine is None: + raise Exception("Unable to find the engine.") + + # This is what we want to the insert into the database: + data: dict[type[Table], list[Table]] = {} + + for app_name in app_names: + app_model = getattr(fixture_pydantic_model, app_name) + + for ( + table_class_name, + model_instance_list, + ) in app_model.__dict__.items(): + table_class = finder.get_table_with_name( + app_name, table_class_name + ) + data[table_class] = [ + table_class.from_dict(row.__dict__) + for row in model_instance_list + ] + + # We have to sort the table classes based on foreign key, so we insert + # the data in the right order. + sorted_table_classes = sort_table_classes(list(data.keys())) + + async with engine.transaction(): + for table_class in sorted_table_classes: + rows = data[table_class] + + for chunk in batch(data=rows, chunk_size=chunk_size): + query = table_class.insert(*chunk) + if on_conflict_action is not None: + query = query.on_conflict( + target=table_class._meta.primary_key, + action=on_conflict_action, + values=table_class._meta.columns, + ) + await query.run() + + +async def load( + path: str = "fixture.json", + chunk_size: int = 1000, + on_conflict: Optional[ + typing_extensions.Literal["DO NOTHING", "DO UPDATE"] + ] = None, +): + """ + Reads the fixture file, and loads the contents into the database. + + :param path: + The path of the fixture file. + + :param chunk_size: + The maximum number of rows to insert at a time. This is usually + determined by the database adapter, which has a max number of + parameters per query. 
+ + :param on_conflict: + If specified, the fixture will be upserted, meaning that if a row + already exists with a matching primary key, then it will be overridden + if "DO UPDATE", or it will be ignored if "DO NOTHING". + + """ + with open(path, "r") as f: + contents = f.read() + + on_conflict_action: Optional[OnConflictAction] = None + + if on_conflict: + try: + on_conflict_action = OnConflictAction(on_conflict.upper()) + except ValueError: + sys.exit( + f"{on_conflict} isn't a valid option - use 'DO NOTHING' or " + "'DO UPDATE'." + ) + + await load_json_string( + contents, + chunk_size=chunk_size, + on_conflict_action=on_conflict_action, + ) diff --git a/piccolo/apps/fixtures/commands/shared.py b/piccolo/apps/fixtures/commands/shared.py new file mode 100644 index 000000000..d2d67a819 --- /dev/null +++ b/piccolo/apps/fixtures/commands/shared.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import pydantic + +from piccolo.conf.apps import Finder +from piccolo.utils.pydantic import create_pydantic_model + +if TYPE_CHECKING: # pragma: no cover + from piccolo.table import Table + + +@dataclass +class FixtureConfig: + app_name: str + table_class_names: list[str] + + +def create_pydantic_fixture_model(fixture_configs: list[FixtureConfig]): + """ + Returns a nested Pydantic model for serialising and deserialising fixtures. + """ + columns: dict[str, Any] = {} + + finder = Finder() + + for fixture_config in fixture_configs: + + app_columns: dict[str, Any] = {} + + for table_class_name in fixture_config.table_class_names: + table_class: type[Table] = finder.get_table_with_name( + app_name=fixture_config.app_name, + table_class_name=table_class_name, + ) + app_columns[table_class_name] = ( + list[ # type: ignore + create_pydantic_model( + table_class, include_default_columns=True + ) + ], + ..., + ) + + app_model: Any = pydantic.create_model( + f"{fixture_config.app_name.title()}Model", **app_columns + ) + + columns[fixture_config.app_name] = (app_model, ...) 
+ + return pydantic.create_model("FixtureModel", **columns) diff --git a/piccolo/apps/fixtures/piccolo_app.py b/piccolo/apps/fixtures/piccolo_app.py new file mode 100644 index 000000000..7468dad89 --- /dev/null +++ b/piccolo/apps/fixtures/piccolo_app.py @@ -0,0 +1,12 @@ +from piccolo.conf.apps import AppConfig + +from .commands.dump import dump +from .commands.load import load + +APP_CONFIG = AppConfig( + app_name="fixtures", + migrations_folder_path="", + table_classes=[], + migration_dependencies=[], + commands=[dump, load], +) diff --git a/piccolo/apps/migrations/auto/__init__.py b/piccolo/apps/migrations/auto/__init__.py index cdffc6c1c..1df58816c 100644 --- a/piccolo/apps/migrations/auto/__init__.py +++ b/piccolo/apps/migrations/auto/__init__.py @@ -2,3 +2,11 @@ from .migration_manager import MigrationManager from .schema_differ import AlterStatements, SchemaDiffer from .schema_snapshot import SchemaSnapshot + +__all__ = [ + "DiffableTable", + "MigrationManager", + "AlterStatements", + "SchemaDiffer", + "SchemaSnapshot", +] diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index 199a67cfa..a3e80500f 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing as t from dataclasses import dataclass, field +from typing import Any, Optional from piccolo.apps.migrations.auto.operations import ( AddColumn, @@ -16,19 +16,52 @@ from piccolo.table import Table, create_table_class -def compare_dicts(dict_1, dict_2) -> t.Dict[str, t.Any]: +def compare_dicts( + dict_1: dict[str, Any], dict_2: dict[str, Any] +) -> dict[str, Any]: """ Returns a new dictionary which only contains key, value pairs which are in the first dictionary and not the second. + + For example:: + + >>> dict_1 = {'a': 1, 'b': 2} + >>> dict_2 = {'a': 1} + >>> compare_dicts(dict_1, dict_2) + {'b': 2} + + >>> dict_1 = {'a': 2, 'b': 2} + >>> dict_2 = {'a': 1} + >>> compare_dicts(dict_1, dict_2) + {'a': 2, 'b': 2} + """ - return dict(set(dict_1.items()) - set(dict_2.items())) + output = {} + + for key, value in dict_1.items(): + dict_2_value = dict_2.get(key, ...) + + if ( + # If the value is `...` then it means no value was found. + (dict_2_value is ...) + # We have to compare the types, because if we just use equality + # then 1.0 == 1 is True. + # See this issue: + # https://github.com/piccolo-orm/piccolo/issues/1071 + or (type(value) is not type(dict_2_value)) + # Finally compare the actual values. 
+ or (dict_2_value != value) + ): + output[key] = value + + return output @dataclass class TableDelta: - add_columns: t.List[AddColumn] = field(default_factory=list) - drop_columns: t.List[DropColumn] = field(default_factory=list) - alter_columns: t.List[AlterColumn] = field(default_factory=list) + add_columns: list[AddColumn] = field(default_factory=list) + drop_columns: list[DropColumn] = field(default_factory=list) + alter_columns: list[AlterColumn] = field(default_factory=list) def __eq__(self, value: TableDelta) -> bool: # type: ignore """ @@ -52,7 +85,10 @@ def __hash__(self) -> int: def __eq__(self, value) -> bool: if isinstance(value, ColumnComparison): - return self.column._meta.name == value.column._meta.name + return ( + self.column._meta.db_column_name + == value.column._meta.db_column_name + ) return False @@ -65,11 +101,12 @@ class DiffableTable: class_name: str tablename: str - columns: t.List[Column] = field(default_factory=list) - previous_class_name: t.Optional[str] = None + schema: Optional[str] = None + columns: list[Column] = field(default_factory=list) + previous_class_name: Optional[str] = None - def __post_init__(self): - self.columns_map: t.Dict[str, Column] = { + def __post_init__(self) -> None: + self.columns_map: dict[str, Column] = { i._meta.name: i for i in self.columns } @@ -84,17 +121,28 @@ def __sub__(self, value: DiffableTable) -> TableDelta: "The two tables don't appear to have the same name." ) + ####################################################################### + + # Because we're using sets here, the order is indeterminate. We sort + # them, otherwise it's difficult to write good unit tests if the order + # constantly changes. + add_columns = [ AddColumn( table_class_name=self.class_name, column_name=i.column._meta.name, + db_column_name=i.column._meta.db_column_name, column_class_name=i.column.__class__.__name__, column_class=i.column.__class__, params=i.column._meta.params, + schema=self.schema, ) - for i in ( + for i in sorted( {ColumnComparison(column=column) for column in self.columns} - - {ColumnComparison(column=column) for column in value.columns} + - { + ColumnComparison(column=column) for column in value.columns + }, + key=lambda x: x.column._meta.name, ) ] @@ -102,15 +150,20 @@ def __sub__(self, value: DiffableTable) -> TableDelta: DropColumn( table_class_name=self.class_name, column_name=i.column._meta.name, + db_column_name=i.column._meta.db_column_name, tablename=value.tablename, + schema=self.schema, ) - for i in ( + for i in sorted( {ColumnComparison(column=column) for column in value.columns} - - {ColumnComparison(column=column) for column in self.columns} + - {ColumnComparison(column=column) for column in self.columns}, + key=lambda x: x.column._meta.name, ) ] - alter_columns: t.List[AlterColumn] = [] + ####################################################################### + + alter_columns: list[AlterColumn] = [] for existing_column in value.columns: column = self.columns_map.get(existing_column._meta.name) @@ -134,10 +187,12 @@ def __sub__(self, value: DiffableTable) -> TableDelta: table_class_name=self.class_name, tablename=self.tablename, column_name=column._meta.name, + db_column_name=column._meta.db_column_name, params=deserialise_params(delta), old_params=old_params, column_class=column.__class__, old_column_class=existing_column.__class__, + schema=self.schema, ) ) @@ -166,16 +221,14 @@ def __eq__(self, value) -> bool: def __str__(self): return f"{self.class_name} - {self.tablename}" - def to_table_class(self) -> 
t.Type[Table]: + def to_table_class(self) -> type[Table]: """ Converts the DiffableTable into a Table subclass. """ - _Table: t.Type[Table] = create_table_class( + return create_table_class( class_name=self.class_name, - class_kwargs={"tablename": self.tablename}, + class_kwargs={"tablename": self.tablename, "schema": self.schema}, class_members={ column._meta.name: column for column in self.columns }, ) - - return _Table diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 704c81d76..ce085783f 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -1,20 +1,31 @@ from __future__ import annotations import inspect -import typing as t +import logging +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field +from typing import Any, Optional, Union from piccolo.apps.migrations.auto.diffable_table import DiffableTable from piccolo.apps.migrations.auto.operations import ( AlterColumn, + ChangeTableSchema, DropColumn, RenameColumn, RenameTable, ) from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types +from piccolo.columns.column_types import ForeignKey, Serial from piccolo.engine import engine_finder -from piccolo.table import Table, create_table_class +from piccolo.query import Query +from piccolo.query.base import DDL +from piccolo.query.constraints import get_fk_constraint_name +from piccolo.schema import SchemaDDLBase +from piccolo.table import Table, create_table_class, sort_table_classes +from piccolo.utils.warnings import colored_warning + +logger = logging.getLogger(__name__) @dataclass @@ -22,18 +33,19 @@ class AddColumnClass: column: Column table_class_name: str tablename: str + schema: Optional[str] @dataclass class AddColumnCollection: - add_columns: t.List[AddColumnClass] = field(default_factory=list) + add_columns: list[AddColumnClass] = field(default_factory=list) def append(self, add_column: AddColumnClass): self.add_columns.append(add_column) def for_table_class_name( self, table_class_name: str - ) -> t.List[AddColumnClass]: + ) -> list[AddColumnClass]: return [ i for i in self.add_columns @@ -42,7 +54,7 @@ def for_table_class_name( def columns_for_table_class_name( self, table_class_name: str - ) -> t.List[Column]: + ) -> list[Column]: return [ i.column for i in self.add_columns @@ -50,20 +62,18 @@ def columns_for_table_class_name( ] @property - def table_class_names(self) -> t.List[str]: - return list(set([i.table_class_name for i in self.add_columns])) + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.add_columns}) @dataclass class DropColumnCollection: - drop_columns: t.List[DropColumn] = field(default_factory=list) + drop_columns: list[DropColumn] = field(default_factory=list) def append(self, drop_column: DropColumn): self.drop_columns.append(drop_column) - def for_table_class_name( - self, table_class_name: str - ) -> t.List[DropColumn]: + def for_table_class_name(self, table_class_name: str) -> list[DropColumn]: return [ i for i in self.drop_columns @@ -71,20 +81,20 @@ def for_table_class_name( ] @property - def table_class_names(self) -> t.List[str]: - return list(set([i.table_class_name for i in self.drop_columns])) + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.drop_columns}) @dataclass class RenameColumnCollection: - rename_columns: 
t.List[RenameColumn] = field(default_factory=list) + rename_columns: list[RenameColumn] = field(default_factory=list) def append(self, rename_column: RenameColumn): self.rename_columns.append(rename_column) def for_table_class_name( self, table_class_name: str - ) -> t.List[RenameColumn]: + ) -> list[RenameColumn]: return [ i for i in self.rename_columns @@ -92,20 +102,18 @@ def for_table_class_name( ] @property - def table_class_names(self) -> t.List[str]: - return list(set([i.table_class_name for i in self.rename_columns])) + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.rename_columns}) @dataclass class AlterColumnCollection: - alter_columns: t.List[AlterColumn] = field(default_factory=list) + alter_columns: list[AlterColumn] = field(default_factory=list) def append(self, alter_column: AlterColumn): self.alter_columns.append(alter_column) - def for_table_class_name( - self, table_class_name: str - ) -> t.List[AlterColumn]: + def for_table_class_name(self, table_class_name: str) -> list[AlterColumn]: return [ i for i in self.alter_columns @@ -113,8 +121,19 @@ def for_table_class_name( ] @property - def table_class_names(self) -> t.List[str]: - return list(set([i.table_class_name for i in self.alter_columns])) + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.alter_columns}) + + +AsyncFunction = Callable[[], Coroutine] + + +class SkippedTransaction: + async def __aenter__(self): + print("Automatic transaction disabled") + + async def __aexit__(self, *args, **kwargs): + pass @dataclass @@ -122,14 +141,24 @@ class MigrationManager: """ Each auto generated migration returns a MigrationManager. It contains all of the schema changes that migration wants to make. + + :param wrap_in_transaction: + By default, the migration is wrapped in a transaction, so if anything + fails, the whole migration will get rolled back. You can disable this + behaviour if you want - for example, in a manual migration you might + want to create the transaction yourself (perhaps you're using + savepoints), or you may want multiple transactions. 
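+
+    For example (a hypothetical sketch of a manual migration)::
+
+        manager = MigrationManager(wrap_in_transaction=False)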
+ """ migration_id: str = "" app_name: str = "" description: str = "" - add_tables: t.List[DiffableTable] = field(default_factory=list) - drop_tables: t.List[DiffableTable] = field(default_factory=list) - rename_tables: t.List[RenameTable] = field(default_factory=list) + preview: bool = False + add_tables: list[DiffableTable] = field(default_factory=list) + drop_tables: list[DiffableTable] = field(default_factory=list) + rename_tables: list[RenameTable] = field(default_factory=list) + change_table_schemas: list[ChangeTableSchema] = field(default_factory=list) add_columns: AddColumnCollection = field( default_factory=AddColumnCollection ) @@ -142,29 +171,55 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) - raw: t.List[t.Union[t.Callable, t.Coroutine]] = field(default_factory=list) - raw_backwards: t.List[t.Union[t.Callable, t.Coroutine]] = field( + raw: list[Union[Callable, AsyncFunction]] = field(default_factory=list) + raw_backwards: list[Union[Callable, AsyncFunction]] = field( default_factory=list ) + fake: bool = False + wrap_in_transaction: bool = True def add_table( self, class_name: str, tablename: str, - columns: t.Optional[t.List[Column]] = None, + schema: Optional[str] = None, + columns: Optional[list[Column]] = None, ): if not columns: columns = [] self.add_tables.append( DiffableTable( - class_name=class_name, tablename=tablename, columns=columns + class_name=class_name, + tablename=tablename, + columns=columns, + schema=schema, ) ) - def drop_table(self, class_name: str, tablename: str): + def drop_table( + self, class_name: str, tablename: str, schema: Optional[str] = None + ): self.drop_tables.append( - DiffableTable(class_name=class_name, tablename=tablename) + DiffableTable( + class_name=class_name, tablename=tablename, schema=schema + ) + ) + + def change_table_schema( + self, + class_name: str, + tablename: str, + new_schema: Optional[str] = None, + old_schema: Optional[str] = None, + ): + self.change_table_schemas.append( + ChangeTableSchema( + class_name=class_name, + tablename=tablename, + new_schema=new_schema, + old_schema=old_schema, + ) ) def rename_table( @@ -173,6 +228,7 @@ def rename_table( old_tablename: str, new_class_name: str, new_tablename: str, + schema: Optional[str] = None, ): self.rename_tables.append( RenameTable( @@ -180,6 +236,7 @@ def rename_table( old_tablename=old_tablename, new_class_name=new_class_name, new_tablename=new_tablename, + schema=schema, ) ) @@ -188,9 +245,11 @@ def add_column( table_class_name: str, tablename: str, column_name: str, + db_column_name: Optional[str] = None, column_class_name: str = "", - column_class: t.Optional[t.Type[Column]] = None, - params: t.Dict[str, t.Any] = {}, + column_class: Optional[type[Column]] = None, + params: Optional[dict[str, Any]] = None, + schema: Optional[str] = None, ): """ Add a new column to the table. @@ -204,6 +263,8 @@ def add_column( A direct reference to a ``Column`` subclass. 
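+
+        For example (an illustrative call - the table and column names
+        are hypothetical)::
+
+            manager.add_column(
+                table_class_name="Band",
+                tablename="band",
+                column_name="genre",
+                column_class_name="Varchar",
+            )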
""" + if params is None: + params = {} column_class = column_class or getattr(column_types, column_class_name) if column_class is None: @@ -212,22 +273,33 @@ def add_column( cleaned_params = deserialise_params(params=params) column = column_class(**cleaned_params) column._meta.name = column_name + if db_column_name: + column._meta.db_column_name = db_column_name + self.add_columns.append( AddColumnClass( column=column, tablename=tablename, table_class_name=table_class_name, + schema=schema, ) ) def drop_column( - self, table_class_name: str, tablename: str, column_name: str + self, + table_class_name: str, + tablename: str, + column_name: str, + db_column_name: Optional[str] = None, + schema: Optional[str] = None, ): self.drop_columns.append( DropColumn( table_class_name=table_class_name, column_name=column_name, + db_column_name=db_column_name or column_name, tablename=tablename, + schema=schema, ) ) @@ -237,6 +309,9 @@ def rename_column( tablename: str, old_column_name: str, new_column_name: str, + old_db_column_name: Optional[str] = None, + new_db_column_name: Optional[str] = None, + schema: Optional[str] = None, ): self.rename_columns.append( RenameColumn( @@ -244,6 +319,9 @@ def rename_column( tablename=tablename, old_column_name=old_column_name, new_column_name=new_column_name, + old_db_column_name=old_db_column_name or old_column_name, + new_db_column_name=new_db_column_name or new_column_name, + schema=schema, ) ) @@ -252,34 +330,42 @@ def alter_column( table_class_name: str, tablename: str, column_name: str, - params: t.Dict[str, t.Any], - old_params: t.Dict[str, t.Any], - column_class: t.Optional[t.Type[Column]] = None, - old_column_class: t.Optional[t.Type[Column]] = None, + db_column_name: Optional[str] = None, + params: Optional[dict[str, Any]] = None, + old_params: Optional[dict[str, Any]] = None, + column_class: Optional[type[Column]] = None, + old_column_class: Optional[type[Column]] = None, + schema: Optional[str] = None, ): """ All possible alterations aren't currently supported. """ + if params is None: + params = {} + if old_params is None: + old_params = {} self.alter_columns.append( AlterColumn( table_class_name=table_class_name, tablename=tablename, column_name=column_name, + db_column_name=db_column_name or column_name, params=params, old_params=old_params, column_class=column_class, old_column_class=old_column_class, + schema=schema, ) ) - def add_raw(self, raw: t.Union[t.Callable, t.Coroutine]): + def add_raw(self, raw: Union[Callable, AsyncFunction]): """ A migration manager can execute arbitrary functions or coroutines when run. This is useful if you want to execute raw SQL. """ self.raw.append(raw) - def add_raw_backwards(self, raw: t.Union[t.Callable, t.Coroutine]): + def add_raw_backwards(self, raw: Union[Callable, AsyncFunction]): """ When reversing a migration, you may want to run extra code to help clean up. @@ -288,13 +374,13 @@ def add_raw_backwards(self, raw: t.Union[t.Callable, t.Coroutine]): ########################################################################### - async def get_table_from_snaphot( + async def get_table_from_snapshot( self, table_class_name: str, - app_name: t.Optional[str], + app_name: Optional[str], offset: int = 0, - migration_id: t.Optional[str] = None, - ) -> t.Type[Table]: + migration_id: Optional[str] = None, + ) -> type[Table]: """ Returns a Table subclass which can be used for modifying data within a migration. 
@@ -312,7 +398,7 @@ async def get_table_from_snaphot(
        if app_name is None:
            app_name = self.app_name

-        diffable_table = await BaseMigrationManager().get_table_from_snaphot(
+        diffable_table = await BaseMigrationManager().get_table_from_snapshot(
            app_name=app_name,
            table_class_name=table_class_name,
            max_migration_id=migration_id,
@@ -322,7 +408,24 @@

    ###########################################################################

-    async def _run_alter_columns(self, backwards=False):
+    @staticmethod
+    async def _print_query(query: Union[DDL, Query, SchemaDDLBase]):
+        if isinstance(query, DDL):
+            print("\n", ";".join(query.ddl) + ";")
+        else:
+            print(str(query))
+
+    async def _run_query(self, query: Union[DDL, Query, SchemaDDLBase]):
+        """
+        If MigrationManager is in preview mode then it just prints the query
+        instead of executing it.
+        """
+        if self.preview:
+            await self._print_query(query)
+        else:
+            await query.run()
+
+    async def _run_alter_columns(self, backwards: bool = False):
        for table_class_name in self.alter_columns.table_class_names:
            alter_columns = self.alter_columns.for_table_class_name(
                table_class_name
@@ -331,13 +434,15 @@
            if not alter_columns:
                continue

-            _Table: t.Type[Table] = create_table_class(
+            _Table: type[Table] = create_table_class(
                class_name=table_class_name,
-                class_kwargs={"tablename": alter_columns[0].tablename},
+                class_kwargs={
+                    "tablename": alter_columns[0].tablename,
+                    "schema": alter_columns[0].schema,
+                },
            )

            for alter_column in alter_columns:
-
                params = (
                    alter_column.old_params
                    if backwards
@@ -371,41 +476,121 @@
                        old_column = old_column_class(**old_params)
                        old_column._meta._table = _Table
                        old_column._meta._name = alter_column.column_name
+                        old_column._meta.db_column_name = (
+                            alter_column.db_column_name
+                        )

                        new_column = column_class(**params)
                        new_column._meta._table = _Table
                        new_column._meta._name = alter_column.column_name
-
-                        await _Table.alter().set_column_type(
-                            old_column=old_column, new_column=new_column
+                        new_column._meta.db_column_name = (
+                            alter_column.db_column_name
                        )
+                        using_expression: Optional[str] = None
+
+                        # Postgres won't automatically cast some types to
+                        # others. We may as well try, as it will definitely
+                        # fail otherwise.
+                        if new_column.value_type != old_column.value_type:
+                            if old_params.get("default", ...) is not None:
+                                # Unless the column's default value is also
+                                # something which can be cast to the new type,
+                                # it will also fail. Drop the default value for
+                                # now - the proper default is set later on.
+                                await self._run_query(
+                                    _Table.alter().drop_default(old_column)
+                                )
+
+                            using_expression = "{}::{}".format(
+                                alter_column.db_column_name,
+                                new_column.column_type,
+                            )
+
+                        # We can't migrate a SERIAL to a BIGSERIAL or vice
+                        # versa, as SERIAL isn't a true type, just an alias to
+                        # other commands.
+                        if issubclass(column_class, Serial) and issubclass(
+                            old_column_class, Serial
+                        ):
+                            colored_warning(
+                                "Unable to migrate Serial to BigSerial and "
+                                "vice versa. This must be done manually."
+ ) + else: + await self._run_query( + _Table.alter().set_column_type( + old_column=old_column, + new_column=new_column, + using_expression=using_expression, + ) + ) + ############################################################### - column_name = alter_column.column_name + on_delete = params.get("on_delete") + on_update = params.get("on_update") + if on_delete is not None or on_update is not None: + existing_table = await self.get_table_from_snapshot( + table_class_name=table_class_name, + app_name=self.app_name, + ) + + fk_column = existing_table._meta.get_column_by_name( + alter_column.column_name + ) + + assert isinstance(fk_column, ForeignKey) + + # First drop the existing foreign key constraint + constraint_name = await get_fk_constraint_name( + column=fk_column + ) + if constraint_name: + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=constraint_name + ) + ) + + # Then add a new foreign key constraint + await self._run_query( + _Table.alter().add_foreign_key_constraint( + column=fk_column, + on_delete=on_delete, + on_update=on_update, + ) + ) null = params.get("null") if null is not None: - await _Table.alter().set_null( - column=column_name, boolean=null - ).run() + await self._run_query( + _Table.alter().set_null( + column=alter_column.db_column_name, boolean=null + ) + ) length = params.get("length") if length is not None: - await _Table.alter().set_length( - column=column_name, length=length - ).run() + await self._run_query( + _Table.alter().set_length( + column=alter_column.db_column_name, length=length + ) + ) unique = params.get("unique") if unique is not None: - # When modifying unique contraints, we need to pass in + # When modifying unique constraints, we need to pass in # a column type, and not just the column name. column = Column() column._meta._table = _Table - column._meta._name = column_name - await _Table.alter().set_unique( - column=column, boolean=unique - ).run() + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + await self._run_query( + _Table.alter().set_unique( + column=column, boolean=unique + ) + ) index = params.get("index") index_method = params.get("index_method") @@ -416,68 +601,86 @@ async def _run_alter_columns(self, backwards=False): # to change the index type. column = Column() column._meta._table = _Table - column._meta._name = column_name - await _Table.drop_index([column]).run() - await _Table.create_index( - [column], method=index_method, if_not_exists=True - ).run() + column._meta._name = alter_column.column_name + column._meta.db_column_name = ( + alter_column.db_column_name + ) + await self._run_query(_Table.drop_index([column])) + await self._run_query( + _Table.create_index( + [column], + method=index_method, + if_not_exists=True, + ) + ) else: # If the index value has changed, then we are either # dropping, or creating an index. column = Column() column._meta._table = _Table - column._meta._name = column_name + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + if index is True: kwargs = ( {"method": index_method} if index_method else {} ) - await _Table.create_index( - [column], if_not_exists=True, **kwargs - ).run() + await self._run_query( + _Table.create_index( + [column], if_not_exists=True, **kwargs + ) + ) else: - await _Table.drop_index([column]).run() + await self._run_query(_Table.drop_index([column])) # None is a valid value, so retrieve ellipsis if not found. 
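+                # e.g. `default=None` means "drop the default", whereas a
+                # missing key means the default is unchanged.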
default = params.get("default", ...) if default is not ...: column = Column() column._meta._table = _Table - column._meta._name = column_name + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + if default is None: - await _Table.alter().drop_default(column=column).run() + await self._run_query( + _Table.alter().drop_default(column=column) + ) else: column.default = default - await _Table.alter().set_default( - column=column, value=column.get_default_value() - ).run() + await self._run_query( + _Table.alter().set_default( + column=column, value=column.get_default_value() + ) + ) # None is a valid value, so retrieve ellipsis if not found. digits = params.get("digits", ...) if digits is not ...: - await _Table.alter().set_digits( - column=alter_column.column_name, - digits=digits, - ).run() + await self._run_query( + _Table.alter().set_digits( + column=alter_column.db_column_name, + digits=digits, + ) + ) async def _run_drop_tables(self, backwards=False): - if backwards: - for diffable_table in self.drop_tables: - _Table = await self.get_table_from_snaphot( + for diffable_table in self.drop_tables: + if backwards: + _Table = await self.get_table_from_snapshot( table_class_name=diffable_table.class_name, app_name=self.app_name, offset=-1, ) - await _Table.create_table().run() - else: - for diffable_table in self.drop_tables: - await ( - diffable_table.to_table_class().alter().drop_table().run() + await self._run_query(_Table.create_table()) + else: + await self._run_query( + diffable_table.to_table_class().alter().drop_table() ) - async def _run_drop_columns(self, backwards=False): + async def _run_drop_columns(self, backwards: bool = False): if backwards: for drop_column in self.drop_columns.drop_columns: - _Table = await self.get_table_from_snaphot( + _Table = await self.get_table_from_snapshot( table_class_name=drop_column.table_class_name, app_name=self.app_name, offset=-1, @@ -485,9 +688,11 @@ async def _run_drop_columns(self, backwards=False): column_to_restore = _Table._meta.get_column_by_name( drop_column.column_name ) - await _Table.alter().add_column( - name=drop_column.column_name, column=column_to_restore - ).run() + await self._run_query( + _Table.alter().add_column( + name=drop_column.column_name, column=column_to_restore + ) + ) else: for table_class_name in self.drop_columns.table_class_names: columns = self.drop_columns.for_table_class_name( @@ -497,17 +702,20 @@ async def _run_drop_columns(self, backwards=False): if not columns: continue - _Table: t.Type[Table] = create_table_class( + _Table = create_table_class( class_name=table_class_name, - class_kwargs={"tablename": columns[0].tablename}, + class_kwargs={ + "tablename": columns[0].tablename, + "schema": columns[0].schema, + }, ) for column in columns: - await _Table.alter().drop_column( - column=column.column_name - ).run() + await self._run_query( + _Table.alter().drop_column(column=column.column_name) + ) - async def _run_rename_tables(self, backwards=False): + async def _run_rename_tables(self, backwards: bool = False): for rename_table in self.rename_tables: class_name = ( rename_table.new_class_name @@ -525,13 +733,19 @@ async def _run_rename_tables(self, backwards=False): else rename_table.new_tablename ) - _Table: t.Type[Table] = create_table_class( - class_name=class_name, class_kwargs={"tablename": tablename} + _Table: type[Table] = create_table_class( + class_name=class_name, + class_kwargs={ + "tablename": tablename, + "schema": rename_table.schema, + }, 
) - await _Table.alter().rename_table(new_name=new_tablename).run() + await self._run_query( + _Table.alter().rename_table(new_name=new_tablename) + ) - async def _run_rename_columns(self, backwards=False): + async def _run_rename_columns(self, backwards: bool = False): for table_class_name in self.rename_columns.table_class_names: columns = self.rename_columns.for_table_class_name( table_class_name @@ -540,51 +754,63 @@ async def _run_rename_columns(self, backwards=False): if not columns: continue - _Table: t.Type[Table] = create_table_class( + _Table: type[Table] = create_table_class( class_name=table_class_name, - class_kwargs={"tablename": columns[0].tablename}, + class_kwargs={ + "tablename": columns[0].tablename, + "schema": columns[0].schema, + }, ) for rename_column in columns: column = ( - rename_column.new_column_name + rename_column.new_db_column_name if backwards - else rename_column.old_column_name + else rename_column.old_db_column_name ) new_name = ( - rename_column.old_column_name + rename_column.old_db_column_name if backwards - else rename_column.new_column_name + else rename_column.new_db_column_name ) - await _Table.alter().rename_column( - column=column, - new_name=new_name, - ).run() + await self._run_query( + _Table.alter().rename_column( + column=column, + new_name=new_name, + ) + ) + + async def _run_add_tables(self, backwards: bool = False): + table_classes: list[type[Table]] = [] + for add_table in self.add_tables: + add_columns: list[AddColumnClass] = ( + self.add_columns.for_table_class_name(add_table.class_name) + ) + _Table: type[Table] = create_table_class( + class_name=add_table.class_name, + class_kwargs={ + "tablename": add_table.tablename, + "schema": add_table.schema, + }, + class_members={ + add_column.column._meta.name: add_column.column + for add_column in add_columns + }, + ) + table_classes.append(_Table) + + # Sort by foreign key, so they're created in the right order. + sorted_table_classes = sort_table_classes(table_classes) - async def _run_add_tables(self, backwards=False): if backwards: - for add_table in self.add_tables: - await add_table.to_table_class().alter().drop_table( - cascade=True - ).run() + for _Table in reversed(sorted_table_classes): + await self._run_query(_Table.alter().drop_table(cascade=True)) else: - for add_table in self.add_tables: - add_columns: t.List[ - AddColumnClass - ] = self.add_columns.for_table_class_name(add_table.class_name) - _Table: t.Type[Table] = create_table_class( - class_name=add_table.class_name, - class_kwargs={"tablename": add_table.tablename}, - class_members={ - add_column.column._meta.name: add_column.column - for add_column in add_columns - }, - ) - - await _Table.create_table().run() + for _Table in sorted_table_classes: + await self._run_query(_Table.create_table()) - async def _run_add_columns(self, backwards=False): + async def _run_add_columns(self, backwards: bool = False): """ Add columns, which belong to existing tables """ @@ -597,88 +823,192 @@ async def _run_add_columns(self, backwards=False): # be deleted. 
continue - _Table: t.Type[Table] = create_table_class( + _Table = create_table_class( class_name=add_column.table_class_name, - class_kwargs={"tablename": add_column.tablename}, + class_kwargs={ + "tablename": add_column.tablename, + "schema": add_column.schema, + }, ) - await _Table.alter().drop_column(add_column.column).run() + await self._run_query( + _Table.alter().drop_column(add_column.column) + ) else: for table_class_name in self.add_columns.table_class_names: if table_class_name in [i.class_name for i in self.add_tables]: continue # No need to add columns to new tables - add_columns: t.List[ - AddColumnClass - ] = self.add_columns.for_table_class_name(table_class_name) + add_columns: list[AddColumnClass] = ( + self.add_columns.for_table_class_name(table_class_name) + ) + ############################################################### # Define the table, with the columns, so the metaclass # sets up the columns correctly. - _Table: t.Type[Table] = create_table_class( + + table_class_members = { + add_column.column._meta.name: add_column.column + for add_column in add_columns + } + + # There's an extreme edge case, when we're adding a foreign + # key which references its own table, for example: + # + # fk = ForeignKey('self') + # + # And that table has a custom primary key, for example: + # + # id = UUID(primary_key=True) + # + # In this situation, we need to know the primary key of the + # table in order to correctly add this new foreign key. + for add_column in add_columns: + if ( + isinstance(add_column.column, ForeignKey) + and add_column.column._meta.params.get("references") + == "self" + ): + try: + existing_table = ( + await self.get_table_from_snapshot( + table_class_name=table_class_name, + app_name=self.app_name, + offset=-1, + ) + ) + except ValueError: + logger.error( + "Unable to find primary key for the table - " + "assuming Serial." + ) + else: + primary_key = existing_table._meta.primary_key + + table_class_members[primary_key._meta.name] = ( + primary_key + ) + + break + + _Table = create_table_class( class_name=add_columns[0].table_class_name, - class_kwargs={"tablename": add_columns[0].tablename}, - class_members={ - add_column.column._meta.name: add_column.column - for add_column in add_columns + class_kwargs={ + "tablename": add_columns[0].tablename, + "schema": add_columns[0].schema, }, + class_members=table_class_members, ) + ############################################################### + for add_column in add_columns: # We fetch the column from the Table, as the metaclass # copies and sets it up properly. column = _Table._meta.get_column_by_name( add_column.column._meta.name ) - await _Table.alter().add_column( - name=column._meta.name, column=column - ).run() + + await self._run_query( + _Table.alter().add_column( + name=column._meta.name, column=column + ) + ) if add_column.column._meta.index: - await _Table.create_index([add_column.column]).run() + await self._run_query( + _Table.create_index([add_column.column]) + ) - async def run(self): - print("Running MigrationManager ...") + async def _run_change_table_schema(self, backwards: bool = False): + from piccolo.schema import SchemaManager - engine = engine_finder() + schema_manager = SchemaManager() - if not engine: - raise Exception("Can't find engine") + for change_table_schema in self.change_table_schemas: + if backwards: + # Note, we don't try dropping any schemas we may have created. + # It's dangerous to do so, just in case the user manually + # added tables etc to the schema, and we delete them. 
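+                # e.g. a table moved from "public" into a hypothetical
+                # "music" schema gets moved back to "public" here, but the
+                # "music" schema itself is left in place.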
- async with engine.transaction(): + if ( + change_table_schema.old_schema + and change_table_schema.old_schema != "public" + ): + await self._run_query( + schema_manager.create_schema( + schema_name=change_table_schema.old_schema, + if_not_exists=True, + ) + ) - for raw in self.raw: - if inspect.iscoroutinefunction(raw): - await raw() - else: - raw() + await self._run_query( + schema_manager.move_table( + table_name=change_table_schema.tablename, + new_schema=change_table_schema.old_schema or "public", + current_schema=change_table_schema.new_schema, + ) + ) - await self._run_add_tables() - await self._run_rename_tables() - await self._run_add_columns() - await self._run_drop_columns() - await self._run_drop_tables() - await self._run_rename_columns() - await self._run_alter_columns() + else: + if ( + change_table_schema.new_schema + and change_table_schema.new_schema != "public" + ): + await self._run_query( + schema_manager.create_schema( + schema_name=change_table_schema.new_schema, + if_not_exists=True, + ) + ) - async def run_backwards(self): - print("Reversing MigrationManager ...") + await self._run_query( + schema_manager.move_table( + table_name=change_table_schema.tablename, + new_schema=change_table_schema.new_schema or "public", + current_schema=change_table_schema.old_schema, + ) + ) + + async def run(self, backwards: bool = False): + direction = "backwards" if backwards else "forwards" + if self.preview: + direction = "preview " + direction + print(f" - {self.migration_id} [{direction}]... ", end="") engine = engine_finder() if not engine: raise Exception("Can't find engine") - async with engine.transaction(): - - for raw in self.raw_backwards: - if inspect.iscoroutinefunction(raw): - await raw() + async with ( + engine.transaction() + if self.wrap_in_transaction + else SkippedTransaction() + ): + if not self.preview: + if direction == "backwards": + raw_list = self.raw_backwards else: - raw() - - await self._run_add_columns(backwards=True) - await self._run_add_tables(backwards=True) - await self._run_drop_tables(backwards=True) - await self._run_rename_tables(backwards=True) - await self._run_drop_columns(backwards=True) - await self._run_rename_columns(backwards=True) - await self._run_alter_columns(backwards=True) + raw_list = self.raw + + for raw in raw_list: + if inspect.iscoroutinefunction(raw): + await raw() + else: + raw() + + await self._run_add_tables(backwards=backwards) + await self._run_change_table_schema(backwards=backwards) + await self._run_rename_tables(backwards=backwards) + await self._run_add_columns(backwards=backwards) + await self._run_drop_columns(backwards=backwards) + await self._run_drop_tables(backwards=backwards) + await self._run_rename_columns(backwards=backwards) + # We can remove this for cockroach when resolved. 
+ # https://github.com/cockroachdb/cockroach/issues/49351 + # "ALTER COLUMN TYPE is not supported inside a transaction" + if engine.engine_type != "cockroach": + await self._run_alter_columns(backwards=backwards) + + if engine.engine_type == "cockroach": + await self._run_alter_columns(backwards=backwards) diff --git a/piccolo/apps/migrations/auto/operations.py b/piccolo/apps/migrations/auto/operations.py index 3da993e13..84e0d261a 100644 --- a/piccolo/apps/migrations/auto/operations.py +++ b/piccolo/apps/migrations/auto/operations.py @@ -1,5 +1,5 @@ -import typing as t from dataclasses import dataclass +from typing import Any, Optional from piccolo.columns.base import Column @@ -10,6 +10,15 @@ class RenameTable: old_tablename: str new_class_name: str new_tablename: str + schema: Optional[str] = None + + +@dataclass +class ChangeTableSchema: + class_name: str + tablename: str + old_schema: Optional[str] + new_schema: Optional[str] @dataclass @@ -18,30 +27,39 @@ class RenameColumn: tablename: str old_column_name: str new_column_name: str + old_db_column_name: str + new_db_column_name: str + schema: Optional[str] = None @dataclass class AlterColumn: table_class_name: str column_name: str + db_column_name: str tablename: str - params: t.Dict[str, t.Any] - old_params: t.Dict[str, t.Any] - column_class: t.Optional[t.Type[Column]] = None - old_column_class: t.Optional[t.Type[Column]] = None + params: dict[str, Any] + old_params: dict[str, Any] + column_class: Optional[type[Column]] = None + old_column_class: Optional[type[Column]] = None + schema: Optional[str] = None @dataclass class DropColumn: table_class_name: str column_name: str + db_column_name: str tablename: str + schema: Optional[str] = None @dataclass class AddColumn: table_class_name: str column_name: str + db_column_name: str column_class_name: str - column_class: t.Type[Column] - params: t.Dict[str, t.Any] + column_class: type[Column] + params: dict[str, Any] + schema: Optional[str] = None diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index d8891dd32..7dbc9a469 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -1,21 +1,33 @@ from __future__ import annotations -import typing as t +import inspect +from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass, field +from typing import Any, Optional from piccolo.apps.migrations.auto.diffable_table import ( DiffableTable, TableDelta, ) -from piccolo.apps.migrations.auto.operations import RenameColumn, RenameTable -from piccolo.apps.migrations.auto.serialisation import Import, serialise_params +from piccolo.apps.migrations.auto.migration_manager import MigrationManager +from piccolo.apps.migrations.auto.operations import ( + ChangeTableSchema, + RenameColumn, + RenameTable, +) +from piccolo.apps.migrations.auto.serialisation import ( + Definition, + Import, + UniqueGlobalNames, + serialise_params, +) from piccolo.utils.printing import get_fixed_length_string @dataclass class RenameTableCollection: - rename_tables: t.List[RenameTable] = field(default_factory=list) + rename_tables: list[RenameTable] = field(default_factory=list) def append(self, renamed_table: RenameTable): self.rename_tables.append(renamed_table) @@ -28,29 +40,43 @@ def old_class_names(self): def new_class_names(self): return [i.new_class_name for i in self.rename_tables] - def renamed_from(self, new_class_name: str) -> t.Optional[str]: + def 
was_renamed_from(self, old_class_name: str) -> bool: + """ + Returns ``True`` if the given class name was renamed. + """ + for rename_table in self.rename_tables: + if rename_table.old_class_name == old_class_name: + return True + return False + + def renamed_from(self, new_class_name: str) -> Optional[str]: """ Returns the old class name, if it exists. """ rename = [ i for i in self.rename_tables if i.new_class_name == new_class_name ] - if len(rename) > 0: - return rename[0].old_class_name - else: - return None + return rename[0].old_class_name if rename else None + + +@dataclass +class ChangeTableSchemaCollection: + collection: list[ChangeTableSchema] = field(default_factory=list) + + def append(self, change_table_schema: ChangeTableSchema): + self.collection.append(change_table_schema) @dataclass class RenameColumnCollection: - rename_columns: t.List[RenameColumn] = field(default_factory=list) + rename_columns: list[RenameColumn] = field(default_factory=list) def append(self, rename_column: RenameColumn): self.rename_columns.append(rename_column) def for_table_class_name( self, table_class_name: str - ) -> t.List[RenameColumn]: + ) -> list[RenameColumn]: return [ i for i in self.rename_columns @@ -68,9 +94,15 @@ def new_column_names(self): @dataclass class AlterStatements: - statements: t.List[str] - extra_imports: t.List[Import] = field(default_factory=list) - extra_definitions: t.List[str] = field(default_factory=list) + statements: list[str] = field(default_factory=list) + extra_imports: list[Import] = field(default_factory=list) + extra_definitions: list[Definition] = field(default_factory=list) + + def extend(self, alter_statements: AlterStatements): + self.statements.extend(alter_statements.statements) + self.extra_imports.extend(alter_statements.extra_imports) + self.extra_definitions.extend(alter_statements.extra_definitions) + return self @dataclass @@ -81,21 +113,24 @@ class SchemaDiffer: sure - for example, whether a column was renamed. """ - schema: t.List[DiffableTable] - schema_snapshot: t.List[DiffableTable] + schema: list[DiffableTable] + schema_snapshot: list[DiffableTable] # Sometimes the SchemaDiffer requires input from a user - for example, # asking if a table was renamed or not. When running in non-interactive # mode (like in a unittest), we can set a default to be used instead, like # 'y'. - auto_input: t.Optional[str] = None + auto_input: Optional[str] = None ########################################################################### - def __post_init__(self): - self.schema_snapshot_map: t.Dict[str, DiffableTable] = { + def __post_init__(self) -> None: + self.schema_snapshot_map: dict[str, DiffableTable] = { i.class_name: i for i in self.schema_snapshot } + self.table_schema_changes_collection = ( + self.check_table_schema_changes() + ) self.rename_tables_collection = self.check_rename_tables() self.rename_columns_collection = self.check_renamed_columns() @@ -103,11 +138,11 @@ def check_rename_tables(self) -> RenameTableCollection: """ Work out whether any of the tables were renamed. """ - drop_tables: t.List[DiffableTable] = list( + drop_tables: list[DiffableTable] = list( set(self.schema_snapshot) - set(self.schema) ) - new_tables: t.List[DiffableTable] = list( + new_tables: list[DiffableTable] = list( set(self.schema) - set(self.schema_snapshot) ) @@ -123,9 +158,23 @@ def check_rename_tables(self) -> RenameTableCollection: # A renamed table should have at least one column remaining with the # same name. 
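+        # For example, if a hypothetical `Manager` table was renamed to
+        # `TeamManager` but kept its `name` column, it will be offered as
+        # a rename candidate rather than a drop and add.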
for new_table in new_tables: - new_column_names = [i._meta.name for i in new_table.columns] + new_column_names = [ + i._meta.db_column_name for i in new_table.columns + ] for drop_table in drop_tables: - drop_column_names = [i._meta.name for i in new_table.columns] + if collection.was_renamed_from( + old_class_name=drop_table.class_name + ): + # We've already detected a table that was renamed from + # this, so we can continue. + # This can happen if we're renaming lots of tables in a + # single migration. + # https://github.com/piccolo-orm/piccolo/discussions/832 + continue + + drop_column_names = [ + i._meta.db_column_name for i in new_table.columns + ] same_column_names = set(new_column_names).intersection( drop_column_names ) @@ -143,9 +192,10 @@ def check_rename_tables(self) -> RenameTableCollection: old_tablename=drop_table.tablename, new_class_name=new_table.class_name, new_tablename=new_table.tablename, + schema=new_table.schema, ) ) - continue + break user_response = ( self.auto_input @@ -164,8 +214,32 @@ def check_rename_tables(self) -> RenameTableCollection: old_tablename=drop_table.tablename, new_class_name=new_table.class_name, new_tablename=new_table.tablename, + schema=new_table.schema, ) ) + break + + return collection + + def check_table_schema_changes(self) -> ChangeTableSchemaCollection: + collection = ChangeTableSchemaCollection() + + for table in self.schema: + snapshot_table = self.schema_snapshot_map.get( + table.class_name, None + ) + if not snapshot_table: + continue + + if table.schema != snapshot_table.schema: + collection.append( + ChangeTableSchema( + class_name=table.class_name, + tablename=table.tablename, + new_schema=table.schema, + old_schema=snapshot_table.schema, + ) + ) return collection @@ -191,42 +265,91 @@ def check_renamed_columns(self) -> RenameColumnCollection: # type. For now, each time a column is added and removed from a # table, ask if it's a rename. - renamed_column_names: t.List[str] = [] + # We track which dropped columns have already been identified by + # the user as renames, so we don't ask them if another column + # was also renamed from it. + used_drop_column_names: list[str] = [] for add_column in delta.add_columns: - if add_column.table_class_name in renamed_column_names: - continue - for drop_column in delta.drop_columns: - user_response = ( - self.auto_input - if self.auto_input - else input( - f"Did you rename the `{drop_column.column_name}` " - f"column to `{add_column.column_name}` on the " - f"`{ add_column.table_class_name }` table? (y/N)" - ) + if drop_column.column_name in used_drop_column_names: + continue + + user_response = self.auto_input or input( + f"Did you rename the `{drop_column.db_column_name}` " # noqa: E501 + f"column to `{add_column.db_column_name}` on the " + f"`{add_column.table_class_name}` table? 
(y/N)"
                    )
                    if user_response.lower() == "y":
-                        renamed_column_names.append(
-                            add_column.table_class_name
-                        )
+                        used_drop_column_names.append(drop_column.column_name)
                        collection.append(
                            RenameColumn(
                                table_class_name=add_column.table_class_name,
                                tablename=drop_column.tablename,
                                old_column_name=drop_column.column_name,
                                new_column_name=add_column.column_name,
+                                old_db_column_name=drop_column.db_column_name,
+                                new_db_column_name=add_column.db_column_name,
+                                schema=add_column.schema,
                            )
                        )
+                        break

        return collection

    ###########################################################################

+    def _stringify_func(
+        self,
+        func: Callable,
+        params: dict[str, Any],
+        prefix: Optional[str] = None,
+    ) -> AlterStatements:
+        """
+        Generates a string representing how to call the given function with the
+        given params. For example::
+
+            def my_callable(arg_1: str, arg_2: str):
+                ...
+
+            >>> _stringify_func(
+            ...     my_callable,
+            ...     {"arg_1": "a", "arg_2": "b"}
+            ... ).statements
+            ['my_callable(arg_1="a", arg_2="b")']
+
+        """
+        signature = inspect.signature(func)
+
+        if "self" in signature.parameters.keys():
+            params["self"] = None
+
+        serialised_params = serialise_params(params)
+
+        func_name = func.__name__
+
+        # This will raise an exception if we're missing parameters, which helps
+        # with debugging:
+        bound = signature.bind(**serialised_params.params)
+        bound.apply_defaults()
+
+        args = bound.arguments
+        if "self" in args:
+            args.pop("self")
+
+        args_str = ", ".join(f"{i}={repr(j)}" for i, j in args.items())
+
+        return AlterStatements(
+            statements=[f"{prefix or ''}{func_name}({args_str})"],
+            extra_definitions=serialised_params.extra_definitions,
+            extra_imports=serialised_params.extra_imports,
+        )
+
+    ###########################################################################
+
    @property
    def create_tables(self) -> AlterStatements:
-        new_tables: t.List[DiffableTable] = list(
+        new_tables: list[DiffableTable] = list(
            set(self.schema) - set(self.schema_snapshot)
        )

@@ -238,16 +361,26 @@
            not in self.rename_tables_collection.new_class_names
        ]

-        return AlterStatements(
-            statements=[
-                f"manager.add_table('{i.class_name}', tablename='{i.tablename}')"  # noqa: E501
-                for i in new_tables
-            ]
-        )
+        alter_statements = AlterStatements()
+
+        for i in new_tables:
+            alter_statements.extend(
+                self._stringify_func(
+                    func=MigrationManager.add_table,
+                    params={
+                        "class_name": i.class_name,
+                        "tablename": i.tablename,
+                        "schema": i.schema,
+                    },
+                    prefix="manager.",
+                )
+            )
+
+        return alter_statements

    @property
    def drop_tables(self) -> AlterStatements:
-        drop_tables: t.List[DiffableTable] = list(
+        drop_tables: list[DiffableTable] = list(
            set(self.schema_snapshot) - set(self.schema)
        )

@@ -259,27 +392,58 @@
            not in self.rename_tables_collection.old_class_names
        ]

-        return AlterStatements(
-            statements=[
-                f"manager.drop_table(class_name='{i.class_name}', tablename='{i.tablename}')"  # noqa: E501
-                for i in drop_tables
-            ]
-        )
+        alter_statements = AlterStatements()
+
+        for i in drop_tables:
+            alter_statements.extend(
+                self._stringify_func(
+                    func=MigrationManager.drop_table,
+                    params={
+                        "class_name": i.class_name,
+                        "tablename": i.tablename,
+                        "schema": i.schema,
+                    },
+                    prefix="manager.",
+                )
+            )
+
+        return alter_statements

    @property
    def rename_tables(self) -> AlterStatements:
-        return AlterStatements(
-            statements=[
-                f"manager.rename_table(old_class_name='{renamed_table.old_class_name}', old_tablename='{renamed_table.old_tablename}', 
new_class_name='{renamed_table.new_class_name}', new_tablename='{renamed_table.new_tablename}')" # noqa - for renamed_table in self.rename_tables_collection.rename_tables # noqa: E501 - ] - ) + alter_statements = AlterStatements() + + for i in self.rename_tables_collection.rename_tables: + alter_statements.extend( + self._stringify_func( + func=MigrationManager.rename_table, + params=i.__dict__, + prefix="manager.", + ) + ) + + return alter_statements + + @property + def change_table_schemas(self) -> AlterStatements: + alter_statements = AlterStatements() + + for i in self.table_schema_changes_collection.collection: + alter_statements.extend( + self._stringify_func( + func=MigrationManager.change_table_schema, + params=i.__dict__, + prefix="manager.", + ) + ) + + return alter_statements ########################################################################### def _get_snapshot_table( self, table_class_name: str - ) -> t.Optional[DiffableTable]: + ) -> Optional[DiffableTable]: snapshot_table = self.schema_snapshot_map.get(table_class_name, None) if snapshot_table: return snapshot_table @@ -291,17 +455,18 @@ def _get_snapshot_table( class_name = self.rename_tables_collection.renamed_from( table_class_name ) - snapshot_table = self.schema_snapshot_map.get(class_name) - if snapshot_table: - snapshot_table.class_name = table_class_name - return snapshot_table + if class_name: + snapshot_table = self.schema_snapshot_map.get(class_name) + if snapshot_table: + snapshot_table.class_name = table_class_name + return snapshot_table return None @property def alter_columns(self) -> AlterStatements: - response: t.List[str] = [] - extra_imports: t.List[Import] = [] - extra_definitions: t.List[str] = [] + response: list[str] = [] + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] for table in self.schema: snapshot_table = self._get_snapshot_table(table.class_name) if snapshot_table: @@ -335,6 +500,11 @@ def alter_columns(self) -> AlterStatements: Import( module=alter_column.column_class.__module__, target=alter_column.column_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{alter_column.column_class.__name__.upper()}", # noqa: E501 + None, + ), ) ) @@ -343,11 +513,21 @@ def alter_columns(self) -> AlterStatements: Import( module=alter_column.old_column_class.__module__, target=alter_column.old_column_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{alter_column.old_column_class.__name__.upper()}", # noqa: E501 + ), ) ) + schema_str = ( + "None" + if alter_column.schema is None + else f'"{alter_column.schema}"' + ) + response.append( - f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{alter_column.column_name}', params={new_params.params}, old_params={old_params.params}, column_class={column_class}, old_column_class={old_column_class})" # noqa: E501 + f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{alter_column.column_name}', db_column_name='{alter_column.db_column_name}', params={new_params.params}, old_params={old_params.params}, column_class={column_class}, old_column_class={old_column_class}, schema={schema_str})" # noqa: E501 ) return AlterStatements( @@ -373,16 +553,20 @@ def drop_columns(self) -> AlterStatements: ): continue + schema_str = ( + "None" if column.schema is None else f'"{column.schema}"' + ) + response.append( - 
f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}')" # noqa: E501 + f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', db_column_name='{column.db_column_name}', schema={schema_str})" # noqa: E501 ) return AlterStatements(statements=response) @property def add_columns(self) -> AlterStatements: - response: t.List[str] = [] - extra_imports: t.List[Import] = [] - extra_definitions: t.List[str] = [] + response: list[str] = [] + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] for table in self.schema: snapshot_table = self._get_snapshot_table(table.class_name) if snapshot_table: @@ -407,11 +591,22 @@ def add_columns(self) -> AlterStatements: Import( module=column_class.__module__, target=column_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{column_class.__name__.upper()}", + None, + ), ) ) + schema_str = ( + "None" + if add_column.schema is None + else f'"{add_column.schema}"' + ) + response.append( - f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{add_column.column_name}', column_class_name='{add_column.column_class_name}', column_class={column_class.__name__}, params={str(cleaned_params)})" # noqa: E501 + f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{add_column.column_name}', db_column_name='{add_column.db_column_name}', column_class_name='{add_column.column_class_name}', column_class={column_class.__name__}, params={str(cleaned_params)}, schema={schema_str})" # noqa: E501 ) return AlterStatements( statements=response, @@ -421,24 +616,30 @@ def add_columns(self) -> AlterStatements: @property def rename_columns(self) -> AlterStatements: - return AlterStatements( - statements=[ - f"manager.rename_column(table_class_name='{i.table_class_name}', tablename='{i.tablename}', old_column_name='{i.old_column_name}', new_column_name='{i.new_column_name}')" # noqa: E501 - for i in self.rename_columns_collection.rename_columns - ] - ) + alter_statements = AlterStatements() + + for i in self.rename_columns_collection.rename_columns: + alter_statements.extend( + self._stringify_func( + func=MigrationManager.rename_column, + params=i.__dict__, + prefix="manager.", + ) + ) + + return alter_statements ########################################################################### @property def new_table_columns(self) -> AlterStatements: - new_tables: t.List[DiffableTable] = list( + new_tables: list[DiffableTable] = list( set(self.schema) - set(self.schema_snapshot) ) - response: t.List[str] = [] - extra_imports: t.List[Import] = [] - extra_definitions: t.List[str] = [] + response: list[str] = [] + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] for table in new_tables: if ( table.class_name @@ -458,11 +659,20 @@ def new_table_columns(self) -> AlterStatements: Import( module=column.__class__.__module__, target=column.__class__.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{column.__class__.__name__.upper()}", + None, + ), ) ) + schema_str = ( + "None" if table.schema is None else f'"{table.schema}"' + ) + response.append( - f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', 
column_class={column.__class__.__name__}, params={str(cleaned_params)})" # noqa: E501 + f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', db_column_name='{column._meta.db_column_name}', column_class_name='{column.__class__.__name__}', column_class={column.__class__.__name__}, params={str(cleaned_params)}, schema={schema_str})" # noqa: E501 ) return AlterStatements( statements=response, @@ -472,14 +682,15 @@ def new_table_columns(self) -> AlterStatements: ########################################################################### - def get_alter_statements(self) -> t.List[AlterStatements]: + def get_alter_statements(self) -> list[AlterStatements]: """ Call to execute the necessary alter commands on the database. """ - alter_statements: t.Dict[str, AlterStatements] = { + alter_statements: dict[str, AlterStatements] = { "Created tables": self.create_tables, "Dropped tables": self.drop_tables, "Renamed tables": self.rename_tables, + "Tables which changed schema": self.change_table_schemas, "Created table columns": self.new_table_columns, "Dropped columns": self.drop_columns, "Columns added to existing tables": self.add_columns, @@ -492,4 +703,4 @@ def get_alter_statements(self) -> t.List[AlterStatements]: count = len(statements.statements) print(f"{_message} {count}") - return [i for i in alter_statements.values()] + return list(alter_statements.values()) diff --git a/piccolo/apps/migrations/auto/schema_snapshot.py b/piccolo/apps/migrations/auto/schema_snapshot.py index 8e1a06fb9..5bf343063 100644 --- a/piccolo/apps/migrations/auto/schema_snapshot.py +++ b/piccolo/apps/migrations/auto/schema_snapshot.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing as t from dataclasses import dataclass, field from piccolo.apps.migrations.auto.diffable_table import DiffableTable @@ -15,21 +14,21 @@ class SchemaSnapshot: """ # In ascending order of date created. 
- managers: t.List[MigrationManager] = field(default_factory=list)
+ managers: list[MigrationManager] = field(default_factory=list)
 ###########################################################################
 def get_table_from_snapshot(self, table_class_name: str) -> DiffableTable:
 snapshot = self.get_snapshot()
 filtered = [i for i in snapshot if i.class_name == table_class_name]
- if len(filtered) == 0:
+ if not filtered:
 raise ValueError(f"No match was found for {table_class_name}")
 return filtered[0]
 ###########################################################################
- def get_snapshot(self) -> t.List[DiffableTable]:
- tables: t.List[DiffableTable] = []
+ def get_snapshot(self) -> list[DiffableTable]:
+ tables: list[DiffableTable] = []
 # Make sure the managers are sorted correctly:
 sorted_managers = sorted(self.managers, key=lambda x: x.migration_id)
@@ -50,6 +49,12 @@ def get_snapshot(self) -> t.List[DiffableTable]:
 table.tablename = rename_table.new_tablename
 break
+ for change_table_schema in manager.change_table_schemas:
+ for table in tables:
+ if table.tablename == change_table_schema.tablename:
+ table.schema = change_table_schema.new_schema
+ break
+
 for table in tables:
 add_columns = manager.add_columns.columns_for_table_class_name(
 table.class_name
 )
@@ -85,13 +90,12 @@ def get_snapshot(self) -> t.List[DiffableTable]:
 if (
 alter_column.column_class
 != alter_column.old_column_class
- ):
- if alter_column.column_class is not None:
- new_column = alter_column.column_class(
- **column._meta.params
- )
- new_column._meta = column._meta
- table.columns[index] = new_column
+ ) and alter_column.column_class is not None:
+ new_column = alter_column.column_class(
+ **column._meta.params
+ )
+ new_column._meta = column._meta
+ table.columns[index] = new_column
 ###############################################################
@@ -103,5 +107,8 @@ def get_snapshot(self) -> t.List[DiffableTable]:
 for column in table.columns:
 if column._meta.name == rename_column.old_column_name:
 column._meta.name = rename_column.new_column_name
+ column._meta.db_column_name = (
+ rename_column.new_db_column_name
+ )
 return tables
diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py
index bc97ca44c..e3d166353 100644
--- a/piccolo/apps/migrations/auto/serialisation.py
+++ b/piccolo/apps/migrations/auto/serialisation.py
@@ -1,14 +1,17 @@
 from __future__ import annotations
+import abc
 import builtins
 import datetime
 import decimal
 import inspect
-import typing as t
 import uuid
+import warnings
+from collections.abc import Callable, Iterable
 from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
+from typing import Any, Optional
 from piccolo.columns import Column
 from piccolo.columns.defaults.base import Default
@@ -21,15 +24,263 @@
 ###############################################################################
+class CanConflictWithGlobalNames(abc.ABC):
+ @abc.abstractmethod
+ def warn_if_is_conflicting_with_global_name(self): ...
+
+
+class UniqueGlobalNamesMeta(type):
+ """
+ Metaclass for ``UniqueGlobalNames``.
+
+ Fulfills the following functions:
+
+ - Assure that no two class attributes have the same value.
+ - Add class attributes `COLUMN_<COLUMN_CLASS_NAME>`
+ to the class for each column type.
+ """ + + def __new__(mcs, name, bases, class_attributes): + class_attributes_with_columns = mcs.merge_class_attributes( + class_attributes, + mcs.get_column_class_attributes(), + ) + + return super().__new__( + mcs, + name, + bases, + mcs.merge_class_attributes( + class_attributes_with_columns, + { + "unique_names": mcs.get_unique_class_attribute_values( + class_attributes_with_columns + ) + }, + ), + ) + + @staticmethod + def get_unique_class_attribute_values( + class_attributes: dict[str, Any], + ) -> set[Any]: + """ + Return class attribute values. + + Raises an error if attribute values are not unique. + """ + + unique_attribute_values = set() + for attribute_name, attribute_value in class_attributes.items(): + # Skip special attributes, i.e. "____" + if attribute_name.startswith("__") and attribute_name.endswith( + "__" + ): + continue + + if attribute_value in unique_attribute_values: + raise ValueError( + f"Duplicate unique global name {attribute_value}" + ) + unique_attribute_values.add(attribute_value) + + return unique_attribute_values + + @staticmethod + def merge_class_attributes( + class_attributes1: dict[str, Any], + class_attributes2: dict[str, Any], + ) -> dict[str, Any]: + """ + Merges two class attribute dictionaries. + + Raise an error if both dictionaries have an attribute + with the same name. + """ + + for attribute_name in class_attributes2: + if attribute_name in class_attributes1: + raise ValueError(f"Duplicate class attribute {attribute_name}") + + return dict(**class_attributes1, **class_attributes2) + + @staticmethod + def get_column_class_attributes() -> dict[str, str]: + """Automatically generates global names for each column type.""" + + import piccolo.columns.column_types + + class_attributes: dict[str, str] = {} + for module_global in piccolo.columns.column_types.__dict__.values(): + try: + if module_global is not Column and issubclass( + module_global, Column + ): + class_attributes[ + f"COLUMN_{module_global.__name__.upper()}" + ] = module_global.__name__ + except TypeError: + pass + + return class_attributes + + +class UniqueGlobalNames(metaclass=UniqueGlobalNamesMeta): + """ + Contains global names that may be used during serialisation. + + The global names are stored as class attributes. Names that may + occur in the global namespace after serialisation should be listed here. + + This class is meant to prevent against the use of conflicting global + names. If possible imports and global definitions should use this + class to ensure that no conflicts arise during serialisation. 
+ """ + + # Piccolo imports + TABLE = Table.__name__ + DEFAULT = Default.__name__ + # Column types are omitted because they are added by metaclass + + # Standard library imports + STD_LIB_ENUM = Enum.__name__ + STD_LIB_MODULE_DECIMAL = "decimal" + + # Third-party library imports + EXTERNAL_MODULE_UUID = "uuid" + EXTERNAL_UUID = f"{EXTERNAL_MODULE_UUID}.{uuid.UUID.__name__}" + + # This attribute is set in metaclass + unique_names: set[str] + + @classmethod + def warn_if_is_conflicting_name( + cls, name: str, name_type: str = "Name" + ) -> None: + """Raise an error if ``name`` matches a class attribute value.""" + + if cls.is_conflicting_name(name): + warnings.warn( + f"{name_type} '{name}' could conflict with global name", + UniqueGlobalNameConflictWarning, + ) + + @classmethod + def is_conflicting_name(cls, name: str) -> bool: + """Check if ``name`` matches a class attribute value.""" + + return name in cls.unique_names + + @staticmethod + def warn_if_are_conflicting_objects( + objects: Iterable[CanConflictWithGlobalNames], + ) -> None: + """ + Call each object's ``raise_if_is_conflicting_with_global_name`` method. + """ + + for obj in objects: + obj.warn_if_is_conflicting_with_global_name() + + +class UniqueGlobalNameConflictWarning(UserWarning): + pass + + +############################################################################### + + +@dataclass +class Import(CanConflictWithGlobalNames): + module: str + target: Optional[str] = None + expect_conflict_with_global_name: Optional[str] = None + + def __post_init__(self) -> None: + if ( + self.expect_conflict_with_global_name is not None + and not UniqueGlobalNames.is_conflicting_name( + self.expect_conflict_with_global_name + ) + ): + raise ValueError( + f"`expect_conflict_with_global_name=" + f'"{self.expect_conflict_with_global_name}"` ' + f"is not listed in `{UniqueGlobalNames.__name__}`" + ) + + def __repr__(self): + if self.target is None: + return f"import {self.module}" + + return f"from {self.module} import {self.target}" + + def __hash__(self): + if self.target is None: + return hash(f"{self.module}") + + return hash(f"{self.module}-{self.target}") + + def __lt__(self, other): + return repr(self) < repr(other) + + def warn_if_is_conflicting_with_global_name(self): + name = self.module if self.target is None else self.target + if name == self.expect_conflict_with_global_name: + return + + if UniqueGlobalNames.is_conflicting_name(name): + UniqueGlobalNames.warn_if_is_conflicting_name( + name, name_type="Import" + ) + + +class Definition(CanConflictWithGlobalNames, abc.ABC): + @abc.abstractmethod + def __repr__(self): ... 
+ + ########################################################################### + # To allow sorting: + + def __lt__(self, value): + return self.__repr__() < value.__repr__() + + def __le__(self, value): + return self.__repr__() <= value.__repr__() + + def __gt__(self, value): + return self.__repr__() > value.__repr__() + + def __ge__(self, value): + return self.__repr__() >= value.__repr__() + + +@dataclass +class SerialisedParams: + params: dict[str, Any] + extra_imports: list[Import] + extra_definitions: list[Definition] = field(default_factory=list) + + +############################################################################### + + +def check_equality(self, other): + if getattr(other, "__hash__", None) is not None: + return self.__hash__() == other.__hash__() + else: + return False + + @dataclass class SerialisedBuiltin: - builtin: t.Any + builtin: Any def __hash__(self): return hash(self.builtin.__name__) def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): return self.builtin.__name__ @@ -43,7 +294,7 @@ def __hash__(self): return self.instance.__hash__() def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): return repr_class_instance(self.instance) @@ -58,15 +309,14 @@ def __hash__(self): return self.instance.__hash__() def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): args = ", ".join( - [ - f"{key}={self.serialised_params.params.get(key).__repr__()}" # noqa: E501 - for key in self.instance._meta.params.keys() - ] + f"{key}={self.serialised_params.params.get(key).__repr__()}" + for key in self.instance._meta.params.keys() ) + return f"{self.instance.__class__.__name__}({args})" @@ -78,15 +328,29 @@ def __hash__(self): return hash(self.__repr__()) def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): return f"{self.instance.__class__.__name__}.{self.instance.name}" @dataclass -class SerialisedTableType: - table_type: t.Type[Table] +class SerialisedReference: + name: str + + def __hash__(self): + return hash(self.__repr__()) + + def __eq__(self, other): + return check_equality(self, other) + + def __repr__(self): + return self.name + + +@dataclass +class SerialisedTableType(Definition): + table_type: type[Table] def __hash__(self): return hash( @@ -94,11 +358,10 @@ def __hash__(self): ) def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) - def __repr__(self): + def __repr__(self) -> str: tablename = self.table_type._meta.tablename - class_name = self.table_type.__name__ # We have to add the primary key column definition too, so foreign # keys can be created with the correct type. @@ -109,40 +372,118 @@ def __repr__(self): serialised_params=serialise_params(params=pk_column._meta.params), ) - return ( - f'class {class_name}(Table, tablename="{tablename}"): ' + ####################################################################### + + # When creating a ForeignKey, the user can specify a column other than + # the primary key to reference. 
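For context, the ``target_column`` option referred to in the comment above lets a ``ForeignKey`` point at a unique column rather than the primary key. A minimal sketch with hypothetical tables (not from this changeset):

from piccolo.columns import ForeignKey, Serial, Varchar
from piccolo.table import Table


class Manager(Table):
    id: Serial
    email = Varchar(unique=True)


class Band(Table):
    id: Serial
    # References Manager.email instead of Manager's primary key:
    manager_email = ForeignKey(Manager, target_column=Manager.email)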
+ serialised_target_columns: set[SerialisedColumnInstance] = set() + + for fk_column in self.table_type._meta._foreign_key_references: + target_column = fk_column._foreign_key_meta.target_column + if target_column is None: + # Just references the primary key + continue + elif type(target_column) is str: + column = self.table_type._meta.get_column_by_name( + target_column + ) + elif isinstance(target_column, Column): + column = self.table_type._meta.get_column_by_name( + target_column._meta.name + ) + else: + raise ValueError("Unrecognised `target_column` value.") + + if column._meta.name == pk_column._meta.name: + # The target column is the foreign key, so no need to add + # it again. + # https://github.com/piccolo-orm/piccolo/issues/1197 + continue + + serialised_target_columns.add( + SerialisedColumnInstance( + column, + serialised_params=serialise_params( + params=column._meta.params + ), + ) + ) + + ####################################################################### + + schema_str = ( + "None" + if self.table_type._meta.schema is None + else f'"{self.table_type._meta.schema}"' + ) + + definition = ( + f"class {self.table_class_name}" + f'({UniqueGlobalNames.TABLE}, tablename="{tablename}", schema={schema_str}): ' # noqa: E501 f"{pk_column_name} = {serialised_pk_column}" ) + for serialised_target_column in serialised_target_columns: + definition += f"; {serialised_target_column.instance._meta.name} = {serialised_target_column}" # noqa: E501 + + return definition + def __lt__(self, other): return repr(self) < repr(other) + @property + def table_class_name(self) -> str: + return self.table_type.__name__ + + def warn_if_is_conflicting_with_global_name(self) -> None: + UniqueGlobalNames.warn_if_is_conflicting_name(self.table_class_name) + @dataclass -class SerialisedEnumType: - enum_type: t.Type[Enum] +class InlineSerialisedEnumType: + enum_type: type[Enum] def __hash__(self): return hash(self.__repr__()) def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): class_name = self.enum_type.__name__ params = {i.name: i.value for i in self.enum_type} - return f"Enum('{class_name}', {params})" + return f"{UniqueGlobalNames.STD_LIB_ENUM}('{class_name}', {params})" + + +@dataclass +class SerialisedEnumTypeDefinition(Definition): + enum_type: type[Enum] + + def __hash__(self): + return hash(self.enum_type.__name__) + + def __eq__(self, other): + return check_equality(self, other) + + def __repr__(self): + definition = InlineSerialisedEnumType( + enum_type=self.enum_type + ).__repr__() + return f"{self.enum_type.__name__} = {definition}" + + def warn_if_is_conflicting_with_global_name(self) -> None: + UniqueGlobalNames.warn_if_is_conflicting_name(self.enum_type.__name__) @dataclass class SerialisedCallable: - callable_: t.Callable + callable_: Callable def __hash__(self): return hash(self.callable_.__name__) def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): return self.callable_.__name__ @@ -156,58 +497,74 @@ def __hash__(self): return self.instance.int def __eq__(self, other): - return self.__hash__() == other.__hash__() + return check_equality(self, other) def __repr__(self): - return f"UUID('{str(self.instance)}')" + return f'{UniqueGlobalNames.EXTERNAL_UUID}("{str(self.instance)}")' -############################################################################### +@dataclass +class SerialisedDecimal: + instance: decimal.Decimal + def 
__hash__(self):
+ return hash(repr(self))
-@dataclass
-class Import:
- module: str
- target: str
+ def __eq__(self, other):
+ return check_equality(self, other)
 def __repr__(self):
- return f"from {self.module} import {self.target}"
+ return f"{UniqueGlobalNames.STD_LIB_MODULE_DECIMAL}." + repr(
+ self.instance
+ ).replace("'", '"')
- def __hash__(self):
- return hash(f"{self.module}-{self.target}")
- def __lt__(self, other):
- return repr(self) < repr(other)
+###############################################################################
-@dataclass
-class SerialisedParams:
- params: t.Dict[str, t.Any]
- extra_imports: t.List[Import]
- extra_definitions: t.List[str] = field(default_factory=list)
+def serialise_params(
+ params: dict[str, Any], inline_enums: bool = True
+) -> SerialisedParams:
+ """
+ When writing column params to a migration file, or outputting to the
+ playground, we need to serialise some of the values.
+ :param inline_enums:
+ If ``True``, enum values are inlined, for example::
+ value=Enum('MyEnum', {'some_value': 'some_value'})
+ Otherwise, it is reproduced as::
+ value=MyEnum
+ And the enum definition is added to
+ ``SerialisedParams.extra_definitions``.
-###############################################################################
-def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
- """
- When writing column params to a migration file, we need to serialise some
- of the values.
 """
 params = deepcopy(params)
- extra_imports: t.List[Import] = []
- extra_definitions: t.List[t.Any] = []
+ extra_imports: list[Import] = []
+ extra_definitions: list[Definition] = []
 for key, value in params.items():
- # Builtins, such as str, list and dict.
 if inspect.getmodule(value) == builtins:
 params[key] = SerialisedBuiltin(builtin=value)
 continue
- # Column instances, which are used by Array definitions.
+ # Column instances if isinstance(value, Column): + # For target_column (which is used by ForeignKey), we can just + # serialise it as the column name: + if key == "target_column": + params[key] = value._meta.name + continue + + ################################################################### + + # For Array definitions, we want to serialise the full column + # definition: + column: Column = value serialised_params: SerialisedParams = serialise_params( params=column._meta.params @@ -218,10 +575,16 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: extra_imports.extend(serialised_params.extra_imports) extra_definitions.extend(serialised_params.extra_definitions) + column_class_name = column.__class__.__name__ extra_imports.append( Import( module=column.__class__.__module__, - target=column.__class__.__name__, + target=column_class_name, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{column_class_name.upper()}", + None, + ), ) ) params[key] = SerialisedColumnInstance( @@ -236,6 +599,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: Import( module=value.__class__.__module__, target=value.__class__.__name__, + expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, ) ) continue @@ -256,13 +620,27 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: # UUIDs if isinstance(value, uuid.UUID): params[key] = SerialisedUUID(instance=value) - extra_imports.append(Import(module="uuid", target="UUID")) + extra_imports.append( + Import( + module=UniqueGlobalNames.EXTERNAL_MODULE_UUID, + expect_conflict_with_global_name=( + UniqueGlobalNames.EXTERNAL_MODULE_UUID + ), + ) + ) continue # Decimals if isinstance(value, decimal.Decimal): - # Already has a good __repr__. 
- extra_imports.append(Import(module="decimal", target="Decimal")) + params[key] = SerialisedDecimal(instance=value) + extra_imports.append( + Import( + module=UniqueGlobalNames.STD_LIB_MODULE_DECIMAL, + expect_conflict_with_global_name=( + UniqueGlobalNames.STD_LIB_MODULE_DECIMAL + ), + ) + ) continue # Enum instances @@ -292,8 +670,15 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: # Enum types if inspect.isclass(value) and issubclass(value, Enum): - params[key] = SerialisedEnumType(enum_type=value) - extra_imports.append(Import(module="enum", target="Enum")) + extra_imports.append( + Import( + module="enum", + target=UniqueGlobalNames.STD_LIB_ENUM, + expect_conflict_with_global_name=( + UniqueGlobalNames.STD_LIB_ENUM + ), + ) + ) for member in value: type_ = type(member.value) module = inspect.getmodule(type_) @@ -304,6 +689,14 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: Import(module=module_name, target=type_.__name__) ) + if inline_enums: + params[key] = InlineSerialisedEnumType(enum_type=value) + else: + params[key] = SerialisedReference(name=value.__name__) + extra_definitions.append( + SerialisedEnumTypeDefinition(enum_type=value) + ) + # Functions if inspect.isfunction(value): if value.__name__ == "": @@ -324,7 +717,26 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: SerialisedTableType(table_type=table_type) ) extra_imports.append( - Import(module=Table.__module__, target="Table") + Import( + module=Table.__module__, + target=UniqueGlobalNames.TABLE, + expect_conflict_with_global_name=UniqueGlobalNames.TABLE, + ) + ) + # also add missing primary key to extra_imports when creating a + # migration with a ForeignKey that uses a LazyTableReference + # https://github.com/piccolo-orm/piccolo/issues/865 + primary_key_class = table_type._meta.primary_key.__class__ + extra_imports.append( + Import( + module=primary_key_class.__module__, + target=primary_key_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{primary_key_class.__name__.upper()}", + None, + ), + ) ) continue @@ -333,13 +745,23 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: params[key] = SerialisedCallable(callable_=value) extra_definitions.append(SerialisedTableType(table_type=value)) extra_imports.append( - Import(module=Table.__module__, target="Table") + Import( + module=Table.__module__, + target=UniqueGlobalNames.TABLE, + expect_conflict_with_global_name=UniqueGlobalNames.TABLE, + ) ) + primary_key_class = value._meta.primary_key.__class__ extra_imports.append( Import( - module=value._meta.primary_key.__class__.__module__, - target=value._meta.primary_key.__class__.__name__, + module=primary_key_class.__module__, + target=primary_key_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{primary_key_class.__name__.upper()}", + None, + ), ) ) # Include the extra imports and definitions required for the @@ -362,14 +784,20 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: # All other types can remain as is. 
+ unique_extra_imports = list(set(extra_imports)) + UniqueGlobalNames.warn_if_are_conflicting_objects(unique_extra_imports) + + unique_extra_definitions = list(set(extra_definitions)) + UniqueGlobalNames.warn_if_are_conflicting_objects(unique_extra_definitions) + return SerialisedParams( params=params, - extra_imports=[i for i in set(extra_imports)], - extra_definitions=[i for i in set(extra_definitions)], + extra_imports=unique_extra_imports, + extra_definitions=unique_extra_definitions, ) -def deserialise_params(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: +def deserialise_params(params: dict[str, Any]) -> dict[str, Any]: """ When reading column params from a migration file, we need to convert them from their serialised form. @@ -381,15 +809,19 @@ def deserialise_params(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: if isinstance(value, str) and not isinstance(value, Enum): if value != "self": params[key] = deserialise_legacy_params(name=key, value=value) + elif isinstance(value, SerialisedColumnInstance): + params[key] = value.instance elif isinstance(value, SerialisedClassInstance): params[key] = value.instance elif isinstance(value, SerialisedUUID): params[key] = value.instance + elif isinstance(value, SerialisedDecimal): + params[key] = value.instance elif isinstance(value, SerialisedCallable): params[key] = value.callable_ elif isinstance(value, SerialisedTableType): params[key] = value.table_type - elif isinstance(value, SerialisedEnumType): + elif isinstance(value, InlineSerialisedEnumType): params[key] = value.enum_type elif isinstance(value, SerialisedEnumInstance): params[key] = value.instance diff --git a/piccolo/apps/migrations/auto/serialisation_legacy.py b/piccolo/apps/migrations/auto/serialisation_legacy.py index 3bec6ce2d..7ccfbf740 100644 --- a/piccolo/apps/migrations/auto/serialisation_legacy.py +++ b/piccolo/apps/migrations/auto/serialisation_legacy.py @@ -1,14 +1,14 @@ from __future__ import annotations import datetime -import typing as t +from typing import Any from piccolo.columns.column_types import OnDelete, OnUpdate from piccolo.columns.defaults.timestamp import TimestampNow -from piccolo.table import Table, create_table_class +from piccolo.table import create_table_class -def deserialise_legacy_params(name: str, value: str) -> t.Any: +def deserialise_legacy_params(name: str, value: str) -> Any: """ Earlier versions of Piccolo serialised parameters differently. This is here purely for backwards compatibility. @@ -26,37 +26,31 @@ def deserialise_legacy_params(name: str, value: str) -> t.Any: "`SomeClassName` or `SomeClassName|some_table_name`." 
) - _Table: t.Type[Table] = create_table_class( + return create_table_class( class_name=class_name, class_kwargs={"tablename": tablename} if tablename else {}, ) - return _Table ########################################################################### - if name == "on_delete": + if name == "default": + if value in {"TimestampDefault.now", "DatetimeDefault.now"}: + return TimestampNow() + try: + _value = datetime.datetime.fromisoformat(value) + except ValueError: + pass + else: + return _value + + elif name == "on_delete": enum_name, item_name = value.split(".") if enum_name == "OnDelete": return getattr(OnDelete, item_name) - ########################################################################### - - if name == "on_update": + elif name == "on_update": enum_name, item_name = value.split(".") if enum_name == "OnUpdate": return getattr(OnUpdate, item_name) - ########################################################################### - - if name == "default": - if value in ("TimestampDefault.now", "DatetimeDefault.now"): - return TimestampNow() - else: - try: - _value = datetime.datetime.fromisoformat(value) - except ValueError: - pass - else: - return _value - return value diff --git a/piccolo/apps/migrations/commands/backwards.py b/piccolo/apps/migrations/commands/backwards.py index acec79b3e..25b72e8e3 100644 --- a/piccolo/apps/migrations/commands/backwards.py +++ b/piccolo/apps/migrations/commands/backwards.py @@ -3,12 +3,14 @@ import os import sys -from piccolo.apps.migrations.auto import MigrationManager +from piccolo.apps.migrations.auto.migration_manager import MigrationManager from piccolo.apps.migrations.commands.base import ( BaseMigrationManager, MigrationResult, ) from piccolo.apps.migrations.tables import Migration +from piccolo.conf.apps import AppConfig, MigrationModule +from piccolo.utils.printing import print_heading class BackwardsMigrationManager(BaseMigrationManager): @@ -18,27 +20,21 @@ def __init__( migration_id: str, auto_agree: bool = False, clean: bool = False, + preview: bool = False, ): self.migration_id = migration_id self.app_name = app_name self.auto_agree = auto_agree self.clean = clean + self.preview = preview super().__init__() - async def run(self) -> MigrationResult: - await self.create_migration_table() - - app_modules = self.get_app_modules() - - migration_modules = {} - - for app_module in app_modules: - app_config = getattr(app_module, "APP_CONFIG") - if app_config.app_name == self.app_name: - migration_modules = self.get_migration_modules( - app_config.migrations_folder_path - ) - break + async def run_migrations_backwards(self, app_config: AppConfig): + migration_modules: dict[str, MigrationModule] = ( + self.get_migration_modules( + app_config.resolved_migrations_folder_path + ) + ) ran_migration_ids = await Migration.get_migrations_which_ran( app_name=self.app_name @@ -46,7 +42,7 @@ async def run(self) -> MigrationResult: if len(ran_migration_ids) == 0: # Make sure a success is returned, as we don't want this # to appear as an error in automated scripts. - message = "No migrations to reverse!" + message = "🏁 No migrations to reverse!" 
print(message) return MigrationResult(success=True, message=message) @@ -79,33 +75,32 @@ async def run(self) -> MigrationResult: ####################################################################### + n = len(reversed_migration_ids) _continue = ( "y" if self.auto_agree else input( - "About to undo the following migrations:\n" - f"{reversed_migration_ids}\n" - "Enter y to continue.\n" - ) + f"Reverse {n} migration{'s' if n != 1 else ''}? [y/N] " + ).lower() ) if _continue == "y": - print("Undoing migrations") - for migration_id in reversed_migration_ids: - print(f"Reversing {migration_id}") migration_module = migration_modules[migration_id] response = await migration_module.forwards() if isinstance(response, MigrationManager): - await response.run_backwards() - - await Migration.delete().where( - Migration.name == migration_id - ).run() - - if self.clean: - os.unlink(migration_module.__file__) - + if self.preview: + response.preview = True + await response.run(backwards=True) + if not self.preview: + await Migration.delete().where( + Migration.name == migration_id + ).run() + + if self.clean and migration_module.__file__: + os.unlink(migration_module.__file__) + + print("ok! ✔️") return MigrationResult(success=True) else: # pragma: no cover @@ -113,45 +108,54 @@ async def run(self) -> MigrationResult: print(message, file=sys.stderr) return MigrationResult(success=False, message=message) + async def run(self) -> MigrationResult: + await self.create_migration_table() + app_config = self.get_app_config(self.app_name) + return await self.run_migrations_backwards(app_config=app_config) + async def run_backwards( app_name: str, migration_id: str = "1", auto_agree: bool = False, clean: bool = False, + preview: bool = False, ) -> MigrationResult: if app_name == "all": sorted_app_names = BaseMigrationManager().get_sorted_app_names() sorted_app_names.reverse() + names = [f"'{name}'" for name in sorted_app_names] _continue = ( "y" if auto_agree else input( - "You're about to undo the migrations for the following apps:\n" - f"{sorted_app_names}\n" - "Are you sure you want to continue?\n" - "Enter y to continue.\n" - ) + "You are about to undo the migrations for the following " + "apps:\n" + f"{', '.join(names)}\n" + "Are you sure you want to continue? [y/N] " + ).lower() ) - if _continue == "y": - for _app_name in sorted_app_names: - print(f"Undoing {_app_name}") - manager = BackwardsMigrationManager( - app_name=_app_name, - migration_id="all", - auto_agree=auto_agree, - ) - await manager.run() - return MigrationResult(success=True) - else: - return MigrationResult(success=False, message="User cancelled") + + if _continue != "y": + return MigrationResult(success=False, message="user cancelled") + for _app_name in sorted_app_names: + print_heading(_app_name) + manager = BackwardsMigrationManager( + app_name=_app_name, + migration_id="all", + auto_agree=auto_agree, + preview=preview, + ) + await manager.run() + return MigrationResult(success=True) else: manager = BackwardsMigrationManager( app_name=app_name, migration_id=migration_id, auto_agree=auto_agree, clean=clean, + preview=preview, ) return await manager.run() @@ -161,6 +165,7 @@ async def backwards( migration_id: str = "1", auto_agree: bool = False, clean: bool = False, + preview: bool = False, ): """ Undo migrations up to a specific migration. @@ -177,6 +182,8 @@ async def backwards( :param clean: If true, the migration files which have been run backwards are deleted from the disk after completing. 
+ :param preview:
+ If true, don't actually run the migration, just print the SQL queries.
 """
 response = await run_backwards(
@@ -184,6 +191,7 @@ async def backwards(
 migration_id=migration_id,
 auto_agree=auto_agree,
 clean=clean,
+ preview=preview,
 )
 if not response.success:
diff --git a/piccolo/apps/migrations/commands/base.py b/piccolo/apps/migrations/commands/base.py
index dcf6734e4..5a4e9615d 100644
--- a/piccolo/apps/migrations/commands/base.py
+++ b/piccolo/apps/migrations/commands/base.py
@@ -3,8 +3,8 @@
 import importlib
 import os
 import sys
-import typing as t
 from dataclasses import dataclass
+from typing import Optional, cast
 from piccolo.apps.migrations.auto.diffable_table import DiffableTable
 from piccolo.apps.migrations.auto.migration_manager import MigrationManager
@@ -16,7 +16,7 @@
 @dataclass
 class MigrationResult:
 success: bool
- message: t.Optional[str] = None
+ message: Optional[str] = None
 class BaseMigrationManager(Finder):
@@ -32,7 +32,7 @@ async def create_migration_table(self) -> bool:
 def get_migration_modules(
 self, folder_path: str
- ) -> t.Dict[str, MigrationModule]:
+ ) -> dict[str, MigrationModule]:
 """
 Imports the migration modules in the given folder path, and returns
 a mapping of migration ID to the corresponding migration module.
@@ -50,8 +50,8 @@ def get_migration_modules(
 if ((i not in excluded) and i.endswith(".py"))
 ]
- modules: t.List[MigrationModule] = [
- t.cast(MigrationModule, importlib.import_module(name))
+ modules: list[MigrationModule] = [
+ cast(MigrationModule, importlib.import_module(name))
 for name in migration_names
 ]
 for m in modules:
@@ -62,19 +62,19 @@ def get_migration_modules(
 return migration_modules
 def get_migration_ids(
- self, migration_module_dict: t.Dict[str, MigrationModule]
- ) -> t.List[str]:
+ self, migration_module_dict: dict[str, MigrationModule]
+ ) -> list[str]:
 """
 Returns a list of migration IDs, from the Python migration files.
 """
- return sorted(list(migration_module_dict.keys()))
+ return sorted(migration_module_dict.keys())
 async def get_migration_managers(
 self,
 app_config: AppConfig,
- max_migration_id: t.Optional[str] = None,
+ max_migration_id: Optional[str] = None,
 offset: int = 0,
- ) -> t.List[MigrationManager]:
+ ) -> list[MigrationManager]:
 """
 Call the forwards coroutine in each migration module. Each one should
 return a `MigrationManager`. Combine all of the results, and return in
@@ -84,13 +84,13 @@ async def get_migration_managers(
 If set, only MigrationManagers up to and including the given
 migration ID will be returned.
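 :param offset:
 Negative values return a slice from the start of the list, e.g.
 ``offset=-1`` excludes the most recent migration manager. (Positive
 values aren't currently supported.)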
""" - migration_managers: t.List[MigrationManager] = [] + migration_managers: list[MigrationManager] = [] - migrations_folder = app_config.migrations_folder_path + migrations_folder = app_config.resolved_migrations_folder_path - migration_modules: t.Dict[ - str, MigrationModule - ] = self.get_migration_modules(migrations_folder) + migration_modules: dict[str, MigrationModule] = ( + self.get_migration_modules(migrations_folder) + ) migration_ids = sorted(migration_modules.keys()) @@ -110,15 +110,15 @@ async def get_migration_managers( "Positive offset values aren't currently supported" ) elif offset < 0: - return migration_managers[0:offset] + return migration_managers[:offset] else: return migration_managers - async def get_table_from_snaphot( + async def get_table_from_snapshot( self, app_name: str, table_class_name: str, - max_migration_id: t.Optional[str] = None, + max_migration_id: Optional[str] = None, offset: int = 0, ) -> DiffableTable: """ diff --git a/piccolo/apps/migrations/commands/check.py b/piccolo/apps/migrations/commands/check.py index c2377c002..ee4bf12a3 100644 --- a/piccolo/apps/migrations/commands/check.py +++ b/piccolo/apps/migrations/commands/check.py @@ -1,5 +1,4 @@ import dataclasses -import typing as t from piccolo.apps.migrations.commands.base import BaseMigrationManager from piccolo.apps.migrations.tables import Migration @@ -19,11 +18,11 @@ def __init__(self, app_name: str): self.app_name = app_name super().__init__() - async def get_migration_statuses(self) -> t.List[MigrationStatus]: + async def get_migration_statuses(self) -> list[MigrationStatus]: # Make sure the migration table exists, otherwise we'll get an error. await self.create_migration_table() - migration_statuses: t.List[MigrationStatus] = [] + migration_statuses: list[MigrationStatus] = [] app_modules = self.get_app_modules() @@ -32,11 +31,11 @@ async def get_migration_statuses(self) -> t.List[MigrationStatus]: app_name = app_config.app_name - if (self.app_name != "all") and (self.app_name != app_name): + if self.app_name not in ["all", app_name]: continue migration_modules = self.get_migration_modules( - app_config.migrations_folder_path + app_config.resolved_migrations_folder_path ) ids = self.get_migration_ids(migration_modules) for _id in ids: diff --git a/piccolo/apps/migrations/commands/clean.py b/piccolo/apps/migrations/commands/clean.py index e2fa7ca6b..5a95015a0 100644 --- a/piccolo/apps/migrations/commands/clean.py +++ b/piccolo/apps/migrations/commands/clean.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing as t +from typing import cast from piccolo.apps.migrations.commands.base import BaseMigrationManager from piccolo.apps.migrations.tables import Migration @@ -12,7 +12,7 @@ def __init__(self, app_name: str, auto_agree: bool = False): self.auto_agree = auto_agree super().__init__() - def get_migration_ids_to_remove(self) -> t.List[str]: + def get_migration_ids_to_remove(self) -> list[str]: """ Returns a list of migration ID strings, which are rows in the table, but don't have a corresponding migration module on disk. @@ -20,7 +20,7 @@ def get_migration_ids_to_remove(self) -> t.List[str]: app_config = self.get_app_config(app_name=self.app_name) migration_module_dict = self.get_migration_modules( - folder_path=app_config.migrations_folder_path + folder_path=app_config.resolved_migrations_folder_path ) # The migration IDs which are in migration modules. 
@@ -37,8 +37,7 @@ def get_migration_ids_to_remove(self) -> t.List[str]: if len(migration_ids) > 0: query = query.where(Migration.name.not_in(migration_ids)) - migration_ids_to_remove = query.run_sync() - return migration_ids_to_remove + return cast(list[str], query.run_sync()) async def run(self): print("Checking the migration table ...") diff --git a/piccolo/apps/migrations/commands/forwards.py b/piccolo/apps/migrations/commands/forwards.py index 5510cb691..adc7e657d 100644 --- a/piccolo/apps/migrations/commands/forwards.py +++ b/piccolo/apps/migrations/commands/forwards.py @@ -1,24 +1,29 @@ from __future__ import annotations import sys -import typing as t -from piccolo.apps.migrations.auto import MigrationManager +from piccolo.apps.migrations.auto.migration_manager import MigrationManager from piccolo.apps.migrations.commands.base import ( BaseMigrationManager, MigrationResult, ) from piccolo.apps.migrations.tables import Migration from piccolo.conf.apps import AppConfig, MigrationModule +from piccolo.utils.printing import print_heading class ForwardsMigrationManager(BaseMigrationManager): def __init__( - self, app_name: str, migration_id: str = "all", fake: bool = False + self, + app_name: str, + migration_id: str = "all", + fake: bool = False, + preview: bool = False, ): self.app_name = app_name self.migration_id = migration_id self.fake = fake + self.preview = preview super().__init__() async def run_migrations(self, app_config: AppConfig) -> MigrationResult: @@ -26,22 +31,26 @@ async def run_migrations(self, app_config: AppConfig) -> MigrationResult: app_name=app_config.app_name ) - migration_modules: t.Dict[ - str, MigrationModule - ] = self.get_migration_modules(app_config.migrations_folder_path) + migration_modules: dict[str, MigrationModule] = ( + self.get_migration_modules( + app_config.resolved_migrations_folder_path + ) + ) ids = self.get_migration_ids(migration_modules) - print(f"All migration ids = {ids}") + n = len(ids) + print(f"👍 {n} migration{'s' if n != 1 else ''} already complete") havent_run = sorted(set(ids) - set(already_ran)) - print(f"Haven't run = {havent_run}") - if len(havent_run) == 0: # Make sure this still appears successful, as we don't want this # to appear as an error in automated scripts. - message = "No migrations left to run!" + message = "🏁 No migrations need to be run" print(message) return MigrationResult(success=True, message=message) + else: + n = len(havent_run) + print(f"⏩ {n} migration{'s' if n != 1 else ''} not yet run") if self.migration_id == "all": subset = havent_run @@ -57,26 +66,32 @@ async def run_migrations(self, app_config: AppConfig) -> MigrationResult: else: subset = havent_run[: index + 1] - for _id in subset: - if self.fake: - print(f"Faked {_id}") - else: + if subset: + n = len(subset) + print(f"🚀 Running {n} migration{'s' if n != 1 else ''}:") + + for _id in subset: migration_module = migration_modules[_id] response = await migration_module.forwards() if isinstance(response, MigrationManager): - await response.run() + if self.fake or response.fake: + print(f"- {_id}: faked! ⏭️") + else: + if self.preview: + response.preview = True + await response.run() - print(f"-> Ran {_id}") + print("ok! 
✔️") - await Migration.insert().add( - Migration(name=_id, app_name=app_config.app_name) - ).run() + if not self.preview: + await Migration.insert().add( + Migration(name=_id, app_name=app_config.app_name) + ).run() - return MigrationResult(success=True, message="Ran successfully") + return MigrationResult(success=True, message="migration succeeded") async def run(self) -> MigrationResult: - print("Running migrations ...") await self.create_migration_table() app_config = self.get_app_config(app_name=self.app_name) @@ -85,7 +100,10 @@ async def run(self) -> MigrationResult: async def run_forwards( - app_name: str, migration_id: str = "all", fake: bool = False + app_name: str, + migration_id: str = "all", + fake: bool = False, + preview: bool = False, ) -> MigrationResult: """ Run the migrations. This function can be used to programatically run @@ -94,10 +112,12 @@ async def run_forwards( if app_name == "all": sorted_app_names = BaseMigrationManager().get_sorted_app_names() for _app_name in sorted_app_names: - print(f"\nMigrating {_app_name}") - print("------------------------------------------------") + print_heading(_app_name) manager = ForwardsMigrationManager( - app_name=_app_name, migration_id="all", fake=fake + app_name=_app_name, + migration_id="all", + fake=fake, + preview=preview, ) response = await manager.run() if not response.success: @@ -107,13 +127,19 @@ async def run_forwards( else: manager = ForwardsMigrationManager( - app_name=app_name, migration_id=migration_id, fake=fake + app_name=app_name, + migration_id=migration_id, + fake=fake, + preview=preview, ) return await manager.run() async def forwards( - app_name: str, migration_id: str = "all", fake: bool = False + app_name: str, + migration_id: str = "all", + fake: bool = False, + preview: bool = False, ): """ Runs any migrations which haven't been run yet. @@ -128,9 +154,15 @@ async def forwards( :param fake: If set, will record the migrations as being run without actually running them. 
+ :param preview:
+ If true, don't actually run the migration, just print the SQL queries.
+
 """
 response = await run_forwards(
- app_name=app_name, migration_id=migration_id, fake=fake
+ app_name=app_name,
+ migration_id=migration_id,
+ fake=fake,
+ preview=preview,
 )
 if not response.success:
diff --git a/piccolo/apps/migrations/commands/new.py b/piccolo/apps/migrations/commands/new.py
index f4d3fdb1a..172de96ed 100644
--- a/piccolo/apps/migrations/commands/new.py
+++ b/piccolo/apps/migrations/commands/new.py
@@ -2,13 +2,13 @@
 import datetime
 import os
-import sys
-import typing as t
+import string
 from dataclasses import dataclass
 from itertools import chain
 from types import ModuleType
+from typing import Optional
-import black # type: ignore
+import black
 import jinja2
 from piccolo import __VERSION__
@@ -20,6 +20,8 @@
 )
 from piccolo.conf.apps import AppConfig, Finder
 from piccolo.engine import SQLiteEngine
+from piccolo.utils.printing import print_heading
+from piccolo.utils.warnings import colored_warning
 from .base import BaseMigrationManager
@@ -31,8 +33,9 @@
 loader=jinja2.FileSystemLoader(searchpath=TEMPLATE_DIRECTORY),
 )
+MIGRATION_MODULES: dict[str, ModuleType] = {}
-MIGRATION_MODULES: t.Dict[str, ModuleType] = {}
+VALID_PYTHON_MODULE_CHARACTERS = string.ascii_lowercase + string.digits + "_"
 def render_template(**kwargs):
@@ -47,11 +50,10 @@ def _create_migrations_folder(migrations_path: str) -> bool:
 """
 if os.path.exists(migrations_path):
 return False
- else:
- os.mkdir(migrations_path)
- with open(os.path.join(migrations_path, "__init__.py"), "w"):
- pass
- return True
+ os.mkdir(migrations_path)
+ with open(os.path.join(migrations_path, "__init__.py"), "w"):
+ pass
+ return True
 @dataclass
@@ -61,6 +63,13 @@ class NewMigrationMeta:
 migration_path: str
+def now():
+ """
+ In a separate function so it's easier to patch in tests.
+ """
+ return datetime.datetime.now()
+
+
 def _generate_migration_meta(app_config: AppConfig) -> NewMigrationMeta:
 """
 Generates the migration ID and filename.
 """
 # chance that the IDs would clash if the migrations are generated
 # programmatically in quick succession (e.g. in a unit test), so they had
 # to be added. The trade-off is a longer ID.
- _id = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S:%f")
+ _id = now().strftime("%Y-%m-%dT%H:%M:%S:%f")
 # Originally we just used the _id as the filename, but colons aren't
 # supported in Windows, so we need to sanitize it. We don't want to
 # change the _id format though, as it would break existing migrations.
 # The filename doesn't have any special significance - only the ID matters.
- filename = _id.replace(":", "-")
+ cleaned_id = _id.replace(":", "_").replace("-", "_").lower()
+
+ # Just in case the app name contains characters which aren't valid for
+ # a Python module.
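# For example (illustrative values, not from the changeset): an app named
# "Music-App" with the generated ID "2024-01-01T12:30:00:000000" ends up
# with the module name "music_app_2024_01_01t12_30_00_000000".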
+ cleaned_app_name = "".join( + [ + i + for i in app_config.app_name.lower().replace("-", "_") + if i in VALID_PYTHON_MODULE_CHARACTERS + ] + ) + + filename = f"{cleaned_app_name}_{cleaned_id}" - path = os.path.join(app_config.migrations_folder_path, f"{filename}.py") + path = os.path.join( + app_config.resolved_migrations_folder_path, f"{filename}.py" + ) return NewMigrationMeta( migration_id=_id, migration_filename=filename, migration_path=path @@ -89,7 +112,10 @@ class NoChanges(Exception): async def _create_new_migration( - app_config: AppConfig, auto: bool = False, description: str = "" + app_config: AppConfig, + auto: bool = False, + description: str = "", + auto_input: Optional[str] = None, ) -> NewMigrationMeta: """ Creates a new migration file on disk. @@ -97,22 +123,22 @@ async def _create_new_migration( meta = _generate_migration_meta(app_config=app_config) if auto: - alter_statements = await AutoMigrationManager().get_alter_statements( - app_config=app_config - ) + alter_statements = await AutoMigrationManager( + auto_input=auto_input + ).get_alter_statements(app_config=app_config) _alter_statements = list( chain(*[i.statements for i in alter_statements]) ) extra_imports = sorted( - list(set(chain(*[i.extra_imports for i in alter_statements]))), + set(chain(*[i.extra_imports for i in alter_statements])), key=lambda x: x.__repr__(), ) extra_definitions = sorted( - list(set(chain(*[i.extra_definitions for i in alter_statements]))), + set(chain(*[i.extra_definitions for i in alter_statements])), ) - if sum([len(i.statements) for i in alter_statements]) == 0: + if sum(len(i.statements) for i in alter_statements) == 0: raise NoChanges() file_contents = render_template( @@ -144,9 +170,13 @@ async def _create_new_migration( class AutoMigrationManager(BaseMigrationManager): + def __init__(self, auto_input: Optional[str] = None, *args, **kwargs): + self.auto_input = auto_input + super().__init__(*args, **kwargs) + async def get_alter_statements( self, app_config: AppConfig - ) -> t.List[AlterStatements]: + ) -> list[AlterStatements]: """ Works out which alter statements are required. """ @@ -163,48 +193,80 @@ async def get_alter_statements( class_name=i.__name__, tablename=i._meta.tablename, columns=i._meta.non_default_columns, + schema=i._meta.schema, ) for i in app_config.table_classes ] # Compare the current schema with the snapshot differ = SchemaDiffer( - schema=current_diffable_tables, schema_snapshot=snapshot + schema=current_diffable_tables, + schema_snapshot=snapshot, + auto_input=self.auto_input, ) - alter_statements = differ.get_alter_statements() - - return alter_statements + return differ.get_alter_statements() ############################################################################### -async def new(app_name: str, auto: bool = False, desc: str = ""): +async def new( + app_name: str, + auto: bool = False, + desc: str = "", + auto_input: Optional[str] = None, +): """ Creates a new migration file in the migrations folder. :param app_name: - The app to create a migration for. + The app to create a migration for. Specify a value of 'all' to create + migrations for all apps (use in conjunction with --auto). :param auto: Auto create the migration contents. :param desc: - A description of what the migration does, for example 'adding name - column'. + A description of what the migration does, for example --desc='adding + name column'. + :param auto_input: + If provided, all prompts for user input will automatically have this + entered. For example, --auto_input='y'. 
""" - print("Creating new migration ...") - engine = Finder().get_engine() if auto and isinstance(engine, SQLiteEngine): - sys.exit("Auto migrations aren't currently supported by SQLite.") + colored_warning("Auto migrations aren't fully supported by SQLite.") - app_config = Finder().get_app_config(app_name=app_name) + if app_name == "all" and not auto: + raise ValueError( + "Only use `--app_name=all` in conjunction with `--auto`." + ) - _create_migrations_folder(app_config.migrations_folder_path) - try: - await _create_new_migration( - app_config=app_config, auto=auto, description=desc + app_names = ( + sorted( + BaseMigrationManager().get_app_names( + sort_by_migration_dependencies=False + ) ) - except NoChanges: - print("No changes detected - exiting.") - sys.exit(0) + if app_name == "all" + else [app_name] + ) + + for app_name in app_names: + print_heading(app_name) + print("🚀 Creating new migration ...") + + app_config = Finder().get_app_config(app_name=app_name) + + _create_migrations_folder(app_config.resolved_migrations_folder_path) + + try: + await _create_new_migration( + app_config=app_config, + auto=auto, + description=desc, + auto_input=auto_input, + ) + except NoChanges: + print("🏁 No changes detected.") + + print("\n✅ Finished\n") diff --git a/piccolo/apps/migrations/commands/templates/migration.py.jinja b/piccolo/apps/migrations/commands/templates/migration.py.jinja index 90331d9b4..70cb3b793 100644 --- a/piccolo/apps/migrations/commands/templates/migration.py.jinja +++ b/piccolo/apps/migrations/commands/templates/migration.py.jinja @@ -1,4 +1,4 @@ -from piccolo.apps.migrations.auto import MigrationManager +from piccolo.apps.migrations.auto.migration_manager import MigrationManager {% for extra_import in extra_imports -%} {{ extra_import }} {% endfor %} diff --git a/piccolo/apps/migrations/tables.py b/piccolo/apps/migrations/tables.py index 91906be4f..782172bbb 100644 --- a/piccolo/apps/migrations/tables.py +++ b/piccolo/apps/migrations/tables.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing as t +from typing import Optional from piccolo.columns import Timestamp, Varchar from piccolo.columns.defaults.timestamp import TimestampNow @@ -14,8 +14,8 @@ class Migration(Table): @classmethod async def get_migrations_which_ran( - cls, app_name: t.Optional[str] = None - ) -> t.List[str]: + cls, app_name: Optional[str] = None + ) -> list[str]: """ Returns the names of migrations which have already run, by inspecting the database. diff --git a/piccolo/apps/playground/commands/run.py b/piccolo/apps/playground/commands/run.py index e8fe044a3..670dcd664 100644 --- a/piccolo/apps/playground/commands/run.py +++ b/piccolo/apps/playground/commands/run.py @@ -2,66 +2,201 @@ Populates a database with an example schema and data, and launches a shell for interacting with the data using Piccolo. 
""" + import datetime import sys import uuid from decimal import Decimal +from enum import Enum +from typing import Optional from piccolo.columns import ( JSON, + M2M, UUID, + Array, Boolean, + Date, ForeignKey, Integer, Interval, + LazyTableReference, Numeric, + Serial, + Text, Timestamp, Varchar, ) -from piccolo.engine import PostgresEngine, SQLiteEngine +from piccolo.columns.readable import Readable +from piccolo.engine import CockroachEngine, PostgresEngine, SQLiteEngine from piccolo.engine.base import Engine from piccolo.table import Table +from piccolo.utils.warnings import colored_string class Manager(Table): + id: Serial name = Varchar(length=50) + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s", + columns=[cls.name], + ) + class Band(Table): + id: Serial name = Varchar(length=50) manager = ForeignKey(references=Manager, null=True) popularity = Integer() + genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s", + columns=[cls.name], + ) + + +class FanClub(Table): + id: Serial + address = Text() + band = ForeignKey(Band, unique=True) class Venue(Table): + id: Serial name = Varchar(length=100) capacity = Integer(default=0) + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s", + columns=[cls.name], + ) + class Concert(Table): + id: Serial band_1 = ForeignKey(Band) band_2 = ForeignKey(Band) venue = ForeignKey(Venue) starts = Timestamp() duration = Interval() + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s and %s at %s", + columns=[ + cls.band_1.name, + cls.band_2.name, + cls.venue.name, + ], + ) + class Ticket(Table): + class TicketType(Enum): + sitting = "sitting" + standing = "standing" + premium = "premium" + + id: Serial concert = ForeignKey(Concert) price = Numeric(digits=(5, 2)) + ticket_type = Varchar(choices=TicketType, default=TicketType.standing) + + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s - %s", + columns=[ + cls.concert._.venue._.name, + cls.ticket_type, + ], + ) class DiscountCode(Table): + id: Serial code = UUID() active = Boolean(default=True, null=True) + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s - %s", + columns=[cls.code, cls.active], + ) + class RecordingStudio(Table): + id: Serial name = Varchar(length=100) - facilities = JSON() + facilities = JSON(null=True) + + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s", + columns=[cls.name], + ) + +class Album(Table): + id: Serial + name = Varchar() + band = ForeignKey(Band) + release_date = Date() + recorded_at = ForeignKey(RecordingStudio) + awards = Array(Varchar()) -TABLES = (Manager, Band, Venue, Concert, Ticket, DiscountCode, RecordingStudio) + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s - %s", + columns=[cls.name, cls.band._.name], + ) + + +class Genre(Table): + id: Serial + name = Varchar() + bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s", + columns=[cls.name], + ) + + +class GenreToBand(Table): + id: Serial + band = ForeignKey(Band) + genre = ForeignKey(Genre) + reason = Text(null=True, default=None) + + +TABLES = ( + Manager, + Band, + FanClub, + Venue, + Concert, + Ticket, + DiscountCode, + RecordingStudio, + Album, + Genre, + 
GenreToBand, +) def populate(): @@ -70,8 +205,7 @@ def populate(): """ for _table in reversed(TABLES): try: - if _table.table_exists().run_sync(): - _table.alter().drop_table().run_sync() + _table.alter().drop_table(if_exists=True).run_sync() except Exception as e: print(e) @@ -87,12 +221,21 @@ def populate(): pythonistas = Band(name="Pythonistas", manager=guido.id, popularity=1000) pythonistas.save().run_sync() + fan_club = FanClub(address="1 Flying Circus, UK", band=pythonistas) + fan_club.save().run_sync() + graydon = Manager(name="Graydon") graydon.save().run_sync() rustaceans = Band(name="Rustaceans", manager=graydon.id, popularity=500) rustaceans.save().run_sync() + anders = Manager(name="Anders") + anders.save().run_sync() + + c_sharps = Band(name="C-Sharps", popularity=700, manager=anders.id) + c_sharps.save().run_sync() + venue = Venue(name="Amazing Venue", capacity=5000) venue.save().run_sync() @@ -108,41 +251,113 @@ def populate(): ticket = Ticket(concert=concert.id, price=Decimal("50.0")) ticket.save().run_sync() - discount_code = DiscountCode(code=uuid.uuid4()) - discount_code.save().run_sync() - - recording_studio = RecordingStudio( - name="Abbey Road", facilities={"restaurant": True, "mixing_desk": True} + DiscountCode.insert( + *[DiscountCode({DiscountCode.code: uuid.uuid4()}) for _ in range(5)] + ).run_sync() + + recording_studio_1 = RecordingStudio( + { + RecordingStudio.name: "Abbey Road", + RecordingStudio.facilities: { + "restaurant": True, + "mixing_desk": True, + "instruments": {"electric_guitars": 10, "drum_kits": 2}, + "technicians": [ + {"name": "Alice Jones"}, + {"name": "Bob Williams"}, + ], + }, + } + ) + recording_studio_1.save().run_sync() + + recording_studio_2 = RecordingStudio( + { + RecordingStudio.name: "Electric Lady", + RecordingStudio.facilities: { + "restaurant": False, + "mixing_desk": True, + "instruments": {"electric_guitars": 6, "drum_kits": 3}, + "technicians": [ + {"name": "Frank Smith"}, + ], + }, + }, ) - recording_studio.save().run_sync() + recording_studio_2.save().run_sync() + + Album.insert( + Album( + { + Album.name: "Awesome album 1", + Album.recorded_at: recording_studio_1, + Album.band: pythonistas, + Album.release_date: datetime.date(year=2021, month=1, day=1), + Album.awards: ["Grammy Award 2021"], + } + ), + Album( + { + Album.name: "Awesome album 2", + Album.recorded_at: recording_studio_2, + Album.band: rustaceans, + Album.release_date: datetime.date(year=2022, month=2, day=2), + Album.awards: ["Mercury Prize 2022"], + } + ), + ).run_sync() + + genres = Genre.insert( + Genre(name="Rock"), + Genre(name="Classical"), + Genre(name="Folk"), + ).run_sync() + + GenreToBand.insert( + GenreToBand( + band=pythonistas.id, + genre=genres[0]["id"], + reason="Because they rock.", + ), + GenreToBand(band=pythonistas.id, genre=genres[2]["id"]), + GenreToBand(band=rustaceans.id, genre=genres[2]["id"]), + GenreToBand(band=c_sharps.id, genre=genres[0]["id"]), + GenreToBand(band=c_sharps.id, genre=genres[1]["id"]), + ).run_sync() def run( engine: str = "sqlite", - user: str = "piccolo", - password: str = "piccolo", + user: Optional[str] = None, + password: Optional[str] = None, database: str = "piccolo_playground", host: str = "localhost", - port: int = 5432, + port: Optional[int] = None, + ipython_profile: bool = False, ): """ Creates a test database to play with. 
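+
+    For example, a typical invocation (a sketch; the flags are documented
+    below)::
+
+        piccolo playground run
+        piccolo playground run --engine=postgres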
:param engine: - Which database engine to use - options are sqlite or postgres + Which database engine to use - options are sqlite, postgres or + cockroach :param user: - Postgres user + Database user (ignored for SQLite) :param password: - Postgres password + Database password (ignored for SQLite) :param database: - Postgres database + Database name (ignored for SQLite) :param host: - Postgres host + Database host (ignored for SQLite) :param port: - Postgres port + Database port (ignored for SQLite) + :param ipython_profile: + Set to true to use your own IPython profile. Located at ~/.ipython/. + For more info see the IPython docs + https://ipython.readthedocs.io/en/stable/config/intro.html. """ try: - import IPython # type: ignore + import IPython except ImportError: sys.exit( "Install iPython using `pip install 'piccolo[playground,sqlite]'` " @@ -154,41 +369,55 @@ def run( { "host": host, "database": database, - "user": user, - "password": password, - "port": port, + "user": user or "piccolo", + "password": password or "piccolo", + "port": port or 5432, + } + ) + elif engine.upper() == "COCKROACH": + db = CockroachEngine( + { + "host": host, + "database": database, + "user": user or "root", + "password": password or "", + "port": port or 26257, } ) - for _table in TABLES: - _table._meta._db = db else: db = SQLiteEngine() - for _table in TABLES: - _table._meta._db = db + for _table in TABLES: + _table._meta._db = db - print("Tables:\n") + print(colored_string("\nTables:\n")) for _table in TABLES: print(_table._table_str(abbreviated=True)) - print("\n") + print("") - print("Try it as a query builder:") - print("Band.select().run_sync()") - print("Band.select(Band.name).run_sync()") - print("Band.select(Band.name, Band.manager.name).run_sync()") + print(colored_string("Try it as a query builder:")) + print("await Band.select()") + print("await Band.select(Band.name)") + print("await Band.select(Band.name, Band.manager.name)") print("\n") - print("Try it as an ORM:") - print( - "b = Band.objects().where(Band.name == 'Pythonistas').first()." - "run_sync()" - ) + print(colored_string("Try it as an ORM:")) + print("b = await Band.objects().where(Band.name == 'Pythonistas').first()") print("b.popularity = 10000") - print("b.save().run_sync()") + print("await b.save()") print("\n") populate() - from IPython.core.interactiveshell import _asyncio_runner # type: ignore + from IPython.core.async_helpers import _asyncio_runner + + if ipython_profile: + print(colored_string("Using your IPython profile\n")) + # To try this out, set `c.TerminalInteractiveShell.colors = "Linux"` + # in `~/.ipython/profile_default/ipython_config.py` to set the terminal + # color. + conf_args = {} + else: + conf_args = {"colors": "neutral"} - IPython.embed(using=_asyncio_runner) + IPython.embed(using=_asyncio_runner, **conf_args) diff --git a/piccolo/apps/schema/__init__.py b/piccolo/apps/schema/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/apps/schema/commands/__init__.py b/piccolo/apps/schema/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/apps/schema/commands/exceptions.py b/piccolo/apps/schema/commands/exceptions.py new file mode 100644 index 000000000..58a8f5537 --- /dev/null +++ b/piccolo/apps/schema/commands/exceptions.py @@ -0,0 +1,14 @@ +class SchemaCommandError(Exception): + """ + Base class for all schema command errors. 
+ """ + + pass + + +class GenerateError(SchemaCommandError): + """ + Raised when an error occurs during schema generation. + """ + + pass diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py new file mode 100644 index 000000000..5e1785784 --- /dev/null +++ b/piccolo/apps/schema/commands/generate.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import itertools +import json +import re +import uuid +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +import black + +from piccolo.apps.migrations.auto.serialisation import serialise_params +from piccolo.apps.schema.commands.exceptions import GenerateError +from piccolo.columns import defaults +from piccolo.columns.base import Column, OnDelete, OnUpdate +from piccolo.columns.column_types import ( + JSON, + JSONB, + UUID, + BigInt, + Boolean, + Bytea, + Date, + DoublePrecision, + ForeignKey, + Integer, + Interval, + Numeric, + Real, + Serial, + SmallInt, + Text, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.columns.defaults.interval import IntervalCustom +from piccolo.columns.indexes import IndexMethod +from piccolo.engine.finder import engine_finder +from piccolo.engine.postgres import PostgresEngine +from piccolo.table import Table, create_table_class, sort_table_classes +from piccolo.utils.naming import _snake_to_camel + +if TYPE_CHECKING: # pragma: no cover + from piccolo.engine.base import Engine + + +class ForeignKeyPlaceholder(Table): + pass + + +@dataclasses.dataclass +class ConstraintTable: + name: str = "" + schema: str = "" + + +@dataclasses.dataclass +class RowMeta: + column_default: str + column_name: str + is_nullable: Literal["YES", "NO"] + table_name: str + character_maximum_length: Optional[int] + data_type: str + numeric_precision: Optional[Union[int, str]] + numeric_scale: Optional[Union[int, str]] + numeric_precision_radix: Optional[Literal[2, 10]] + + @classmethod + def get_column_name_str(cls) -> str: + return ", ".join(i.name for i in dataclasses.fields(cls)) + + +@dataclasses.dataclass +class Constraint: + constraint_type: Literal["PRIMARY KEY", "UNIQUE", "FOREIGN KEY", "CHECK"] + constraint_name: str + constraint_schema: Optional[str] = None + column_name: Optional[str] = None + + +@dataclasses.dataclass +class TableConstraints: + """ + All of the constraints for a certain table in the database. 
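+
+    For example (a sketch; the constraint values are hypothetical)::
+
+        constraints = TableConstraints(
+            tablename="band",
+            constraints=[
+                Constraint(
+                    constraint_type="PRIMARY KEY",
+                    constraint_name="band_pkey",
+                    column_name="id",
+                ),
+            ],
+        )
+        constraints.is_primary_key("id")  # True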
+ """ + + tablename: str + constraints: list[Constraint] + + def __post_init__(self) -> None: + foreign_key_constraints: list[Constraint] = [] + unique_constraints: list[Constraint] = [] + primary_key_constraints: list[Constraint] = [] + + for constraint in self.constraints: + if constraint.constraint_type == "FOREIGN KEY": + foreign_key_constraints.append(constraint) + elif constraint.constraint_type == "PRIMARY KEY": + primary_key_constraints.append(constraint) + elif constraint.constraint_type == "UNIQUE": + unique_constraints.append(constraint) + + self.foreign_key_constraints = foreign_key_constraints + self.unique_constraints = unique_constraints + self.primary_key_constraints = primary_key_constraints + + def is_primary_key(self, column_name: str) -> bool: + return any( + i.column_name == column_name for i in self.primary_key_constraints + ) + + def is_unique(self, column_name: str) -> bool: + return any( + i.column_name == column_name for i in self.unique_constraints + ) + + def is_foreign_key(self, column_name: str) -> bool: + return any( + i.column_name == column_name for i in self.foreign_key_constraints + ) + + def get_foreign_key_constraint_name(self, column_name) -> ConstraintTable: + for i in self.foreign_key_constraints: + if i.column_name == column_name: + return ConstraintTable( + name=i.constraint_name, + schema=i.constraint_schema or "public", + ) + + raise ValueError("No matching constraint found") + + +@dataclasses.dataclass +class Trigger: + constraint_name: str + constraint_type: str + table_name: str + column_name: str + on_update: str + on_delete: Literal[ + "NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET_DEFAULT" + ] + references_table: str + references_column: str + + +@dataclasses.dataclass +class TableTriggers: + """ + All of the triggers for a certain table in the database. + """ + + tablename: str + triggers: list[Trigger] + + def get_column_triggers(self, column_name: str) -> list[Trigger]: + return [i for i in self.triggers if i.column_name == column_name] + + def get_column_ref_trigger( + self, column_name: str, references_table: str + ) -> Optional[Trigger]: + for trigger in self.triggers: + if ( + trigger.column_name == column_name + and trigger.references_table == references_table + ): + return trigger + + return None + + +@dataclasses.dataclass +class Index: + indexname: str + indexdef: str + + def __post_init__(self): + """ + An example DDL statement which will be parsed: + + .. code-block:: sql + + CREATE INDEX some_index_name + ON some_schema.some_table + USING some_index_type (some_column_name) + + If the column name is the same as a Postgres data type, then Postgres + wraps the column name in quotes. For example, ``"time"`` instead of + ``time``. + + """ + pat = re.compile( + r"""^CREATE[ ](?:(?PUNIQUE)[ ])?INDEX[ ]\w+?[ ] + ON[ ].+?[ ]USING[ ](?P\w+?)[ ] + \(\"?(?P\w+?\"?)\)""", + re.VERBOSE, + ) + match = re.match(pat, self.indexdef) + if match is None: + self.column_name = None + self.unique = None + self.method = None + self.warnings = [f"{self.indexdef};"] + else: + groups = match.groupdict() + + self.column_name = groups["column_name"].lstrip('"').rstrip('"') + self.unique = "unique" in groups + self.method = INDEX_METHOD_MAP[groups["method"]] + self.warnings = [] + + +@dataclasses.dataclass +class TableIndexes: + """ + All of the indexes for a certain table in the database. 
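+
+    For example (a sketch; the index name and DDL are hypothetical)::
+
+        indexes = TableIndexes(
+            tablename="band",
+            indexes=[
+                Index(
+                    indexname="band_name_idx",
+                    indexdef=(
+                        "CREATE INDEX band_name_idx ON public.band "
+                        "USING btree (name)"
+                    ),
+                ),
+            ],
+        )
+        index = indexes.get_column_index("name")
+        index.method  # IndexMethod.btree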
+ """ + + tablename: str + indexes: list[Index] + + def get_column_index(self, column_name: str) -> Optional[Index]: + return next( + (i for i in self.indexes if i.column_name == column_name), None + ) + + def get_warnings(self) -> list[str]: + return list( + itertools.chain(*[index.warnings for index in self.indexes]) + ) + + +@dataclasses.dataclass +class OutputSchema: + """ + Represents the schema which will be printed out. + :param imports: + e.g. ["from piccolo.table import Table"] + :param warnings: + e.g. ["some_table.some_column unrecognised_type"] + :param index_warnings: + Warnings if column indexes can't be parsed. + :param trigger_warnings: + Warnings if triggers for certain columns can't be found. + :param tables: + e.g. ["class MyTable(Table): ..."] + """ + + imports: list[str] = dataclasses.field(default_factory=list) + warnings: list[str] = dataclasses.field(default_factory=list) + index_warnings: list[str] = dataclasses.field(default_factory=list) + trigger_warnings: list[str] = dataclasses.field(default_factory=list) + tables: list[type[Table]] = dataclasses.field(default_factory=list) + + def get_table_with_name(self, tablename: str) -> Optional[type[Table]]: + """ + Used to search for a table by name. + """ + tablename = _snake_to_camel(tablename) + try: + return next( + table for table in self.tables if table.__name__ == tablename + ) + except StopIteration: + return None + + def __radd__(self, value: OutputSchema) -> OutputSchema: + if isinstance(value, int): + return self + value.imports.extend(self.imports) + value.warnings.extend(self.warnings) + value.index_warnings.extend(self.index_warnings) + value.trigger_warnings.extend(self.trigger_warnings) + value.tables.extend(self.tables) + return value + + def __add__(self, value: OutputSchema) -> OutputSchema: + self.imports.extend(value.imports) + self.warnings.extend(value.warnings) + self.index_warnings.extend(value.index_warnings) + self.trigger_warnings.extend(value.trigger_warnings) + self.tables.extend(value.tables) + return self + + +COLUMN_TYPE_MAP: dict[str, type[Column]] = { + "bigint": BigInt, + "boolean": Boolean, + "bytea": Bytea, + "character varying": Varchar, + "date": Date, + "integer": Integer, + "interval": Interval, + "json": JSON, + "jsonb": JSONB, + "numeric": Numeric, + "real": Real, + "double precision": DoublePrecision, + "smallint": SmallInt, + "text": Text, + "timestamp with time zone": Timestamptz, + "timestamp without time zone": Timestamp, + "uuid": UUID, +} + +# Re-map for Cockroach compatibility. +COLUMN_TYPE_MAP_COCKROACH: dict[str, type[Column]] = { + **COLUMN_TYPE_MAP, + **{"integer": BigInt, "json": JSONB}, +} + +COLUMN_DEFAULT_PARSER: dict[type[Column], Any] = { + BigInt: re.compile(r"^'?(?P-?[0-9]\d*)'?(?:::bigint)?$"), + Boolean: re.compile(r"^(?Ptrue|false)$"), + Bytea: re.compile(r"'(?P.*)'::bytea$"), + DoublePrecision: re.compile(r"(?P[+-]?(?:[0-9]*[.])?[0-9]+)"), + Varchar: re.compile(r"^'(?P.*)'::character varying$"), + Date: re.compile(r"^(?P(?:\d{4}-\d{2}-\d{2})|CURRENT_DATE)$"), + Integer: re.compile(r"^(?P-?\d+)$"), + Interval: re.compile( + r"""^ + (?:')? + (?: + (?:(?P\d+)[ ]y(?:ear(?:s)?)?\b) | + (?:(?P\d+)[ ]m(?:onth(?:s)?)?\b) | + (?:(?P\d+)[ ]w(?:eek(?:s)?)?\b) | + (?:(?P\d+)[ ]d(?:ay(?:s)?)?\b) | + (?: + (?: + (?:(?P\d+)[ ]h(?:our(?:s)?)?\b) | + (?:(?P\d+)[ ]m(?:inute(?:s)?)?\b) | + (?:(?P\d+)[ ]s(?:econd(?:s)?)?\b) + ) | + (?: + (?P-?\d{2}:\d{2}:\d{2}))?\b) + ) + +(?Pago)? + (?:'::interval)? 
+ $""", + re.X, + ), + JSON: re.compile(r"^'(?P.*)'::json$"), + JSONB: re.compile(r"^'(?P.*)'::jsonb$"), + Numeric: re.compile(r"(?P\d+)"), + Real: re.compile(r"^(?P-?[0-9]\d*(?:\.\d+)?)$"), + SmallInt: re.compile(r"^'?(?P-?[0-9]\d*)'?(?:::integer)?$"), + Text: re.compile(r"^'(?P.*)'::text$"), + Timestamp: re.compile( + r"""^ + (?P + (?:\d{4}-\d{2}-\d{2}[ ]\d{2}:\d{2}:\d{2}) | + CURRENT_TIMESTAMP + ) + $""", + re.VERBOSE, + ), + Timestamptz: re.compile( + r"""^ + (?P + (?:\d{4}-\d{2}-\d{2}[ ]\d{2}:\d{2}:\d{2}(?:\.\d+)?-\d{2}) | + CURRENT_TIMESTAMP + ) + $""", + re.VERBOSE, + ), + UUID: None, + Serial: None, + ForeignKey: None, +} + +# Re-map for Cockroach compatibility. +COLUMN_DEFAULT_PARSER_COCKROACH: dict[type[Column], Any] = { + **COLUMN_DEFAULT_PARSER, + BigInt: re.compile(r"^(?P-?\d+)$"), +} + + +def get_column_default( + column_type: type[Column], column_default: str, engine_type: str +) -> Any: + if engine_type == "cockroach": + pat = COLUMN_DEFAULT_PARSER_COCKROACH.get(column_type) + else: + pat = COLUMN_DEFAULT_PARSER.get(column_type) + + # Strip extra, incorrect typing stuff from Cockroach. + column_default = column_default.split(":::", 1)[0] + + if pat is None: + return None + else: + match = re.match(pat, column_default) + if match is not None: + value = match.groupdict() + + if column_type is Boolean: + return value["value"] == "true" + elif column_type is Interval: + kwargs = {} + for period in [ + "years", + "months", + "weeks", + "days", + "hours", + "minutes", + "seconds", + ]: + period_match = value.get(period, 0) + if period_match: + kwargs[period] = int(period_match) + digits = value["digits"] + if digits: + kwargs.update( + dict( + zip( + ["hours", "minutes", "seconds"], + [int(v) for v in digits.split(":")], + ) + ) + ) + + return IntervalCustom(**kwargs) + elif column_type is JSON or column_type is JSONB: + return json.loads(value["value"]) + elif column_type is UUID: + return uuid.uuid4 + elif column_type is Date: + return ( + date.today + if value["value"] == "CURRENT_DATE" + else defaults.date.DateCustom( + *[int(v) for v in value["value"].split("-")] + ) + ) + elif column_type is Bytea: + return value["value"].encode("utf8") + elif column_type is Timestamp: + return ( + datetime.now + if value["value"] == "CURRENT_TIMESTAMP" + else datetime.fromtimestamp(float(value["value"])) + ) + elif column_type is Timestamptz: + return ( + datetime.now + if value["value"] == "CURRENT_TIMESTAMP" + else datetime.fromtimestamp(float(value["value"])) + ) + else: + return column_type.value_type(value["value"]) + + +INDEX_METHOD_MAP: dict[str, IndexMethod] = { + "btree": IndexMethod.btree, + "hash": IndexMethod.hash, + "gist": IndexMethod.gist, + "gin": IndexMethod.gin, +} + + +# 'Indices' seems old-fashioned and obscure in this context. +async def get_indexes( # noqa: E302 + table_class: type[Table], tablename: str, schema_name: str = "public" +) -> TableIndexes: + """ + Get all of the constraints for a table. + + :param table_class: + Any Table subclass - just used to execute raw queries on the database. + + """ + indexes = await table_class.raw( + ( + "SELECT indexname, indexdef " + "FROM pg_indexes " + "WHERE schemaname = {} " + "AND tablename = {}; " + ), + schema_name, + tablename, + ) + + return TableIndexes( + tablename=tablename, + indexes=[Index(**i) for i in indexes], + ) + + +async def get_fk_triggers( + table_class: type[Table], tablename: str, schema_name: str = "public" +) -> TableTriggers: + """ + Get all of the constraints for a table. 
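+    Specifically, the ``FOREIGN KEY`` constraints, along with their
+    ``ON UPDATE`` and ``ON DELETE`` rules.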
+ + :param table_class: + Any Table subclass - just used to execute raw queries on the database. + + """ + # TODO - Move this query to `piccolo.query.constraints` or use: + # `piccolo.query.constraints.referential_constraints` + triggers = await table_class.raw( + ( + "SELECT tc.constraint_name, " + " tc.constraint_type, " + " tc.table_name, " + " kcu.column_name, " + " rc.update_rule AS on_update, " + " rc.delete_rule AS on_delete, " + " ccu.table_name AS references_table, " + " ccu.column_name AS references_column " + "FROM information_schema.table_constraints tc " + "LEFT JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_catalog = kcu.constraint_catalog " + " AND tc.constraint_schema = kcu.constraint_schema " + " AND tc.constraint_name = kcu.constraint_name " + "LEFT JOIN information_schema.referential_constraints rc " + " ON tc.constraint_catalog = rc.constraint_catalog " + " AND tc.constraint_schema = rc.constraint_schema " + " AND tc.constraint_name = rc.constraint_name " + "LEFT JOIN information_schema.constraint_column_usage ccu " + " ON rc.unique_constraint_catalog = ccu.constraint_catalog " + " AND rc.unique_constraint_schema = ccu.constraint_schema " + " AND rc.unique_constraint_name = ccu.constraint_name " + "WHERE lower(tc.constraint_type) in ('foreign key')" + "AND tc.table_schema = {} " + "AND tc.table_name = {}; " + ), + schema_name, + tablename, + ) + return TableTriggers( + tablename=tablename, + triggers=[Trigger(**i) for i in triggers], + ) + + +async def get_constraints( + table_class: type[Table], tablename: str, schema_name: str = "public" +) -> TableConstraints: + """ + Get all of the constraints for a table. + + :param table_class: + Any Table subclass - just used to execute raw queries on the database. + :param tablename: + Name of the table. + :param schema_name: + Name of the schema. + + """ + constraints = await table_class.raw( + ( + "SELECT tc.constraint_name, tc.constraint_type, kcu.column_name, tc.constraint_schema " # noqa: E501 + "FROM information_schema.table_constraints tc " + "LEFT JOIN information_schema.key_column_usage kcu " + "ON tc.constraint_name = kcu.constraint_name " + "WHERE tc.table_schema = {} " + "AND tc.table_name = {} " + ), + schema_name, + tablename, + ) + return TableConstraints( + tablename=tablename, + constraints=[Constraint(**i) for i in constraints], + ) + + +async def get_tablenames( + table_class: type[Table], schema_name: str = "public" +) -> list[str]: + """ + Get the tablenames for the schema. + + :param table_class: + Any Table subclass - just used to execute raw queries on the database. + :param schema_name: + Name of the schema. + :returns: + A list of tablenames for the given schema. + + """ + return [ + i["tablename"] + for i in await table_class.raw( + ( + "SELECT tablename FROM pg_catalog.pg_tables WHERE " + "schemaname = {}" + ), + schema_name, + ).run() + ] + + +async def get_table_schema( + table_class: type[Table], tablename: str, schema_name: str = "public" +) -> list[RowMeta]: + """ + Get the schema from the database. + + :param table_class: + Any Table subclass - just used to execute raw queries on the database. + :param tablename: + The name of the table whose schema we want from the database. + :param schema_name: + A Postgres database can have multiple schemas, this is the name of the + one you're interested in. + :returns: + A list, with each item containing information about a column in the + table. 
+ + """ + row_meta_list = await table_class.raw( + ( + f"SELECT {RowMeta.get_column_name_str()} FROM " + "information_schema.columns " + "WHERE table_schema = {} " + "AND table_name = {}" + ), + schema_name, + tablename, + ).run() + return [RowMeta(**row_meta) for row_meta in row_meta_list] + + +async def get_foreign_key_reference( + table_class: type[Table], constraint_name: str, constraint_schema: str +) -> ConstraintTable: + """ + Retrieve the name of the table that a foreign key is referencing. + """ + response = await table_class.raw( + ( + "SELECT table_name, table_schema " + "FROM information_schema.constraint_column_usage " + "WHERE constraint_name = {} AND constraint_schema = {};" + ), + constraint_name, + constraint_schema, + ) + if len(response) > 0: + return ConstraintTable( + name=response[0]["table_name"], schema=response[0]["table_schema"] + ) + else: + return ConstraintTable() + + +async def create_table_class_from_db( + table_class: type[Table], + tablename: str, + schema_name: str, + engine_type: str, +) -> OutputSchema: + output_schema = OutputSchema() + + indexes = await get_indexes( + table_class=table_class, tablename=tablename, schema_name=schema_name + ) + output_schema.index_warnings.extend(indexes.get_warnings()) + + constraints = await get_constraints( + table_class=table_class, tablename=tablename, schema_name=schema_name + ) + triggers = await get_fk_triggers( + table_class=table_class, tablename=tablename, schema_name=schema_name + ) + table_schema = await get_table_schema( + table_class=table_class, tablename=tablename, schema_name=schema_name + ) + + columns: dict[str, Column] = {} + + for pg_row_meta in table_schema: + data_type = pg_row_meta.data_type + + if engine_type == "cockroach": + column_type = COLUMN_TYPE_MAP_COCKROACH.get(data_type, None) + else: + column_type = COLUMN_TYPE_MAP.get(data_type, None) + + column_name = pg_row_meta.column_name + column_default = pg_row_meta.column_default + if not column_type: + output_schema.warnings.append( + f"{tablename}.{column_name} ['{data_type}']" + ) + column_type = Column + + kwargs: dict[str, Any] = { + "null": pg_row_meta.is_nullable == "YES", + "unique": constraints.is_unique(column_name=column_name), + } + + index = indexes.get_column_index(column_name=column_name) + if index is not None: + kwargs["index"] = True + kwargs["index_method"] = index.method + + if constraints.is_primary_key(column_name=column_name): + kwargs["primary_key"] = True + if column_type == Integer: + column_type = Serial + if column_type == BigInt: + column_type = Serial + # column_type = BigSerial + + if constraints.is_foreign_key(column_name=column_name): + fk_constraint_table = constraints.get_foreign_key_constraint_name( + column_name=column_name + ) + column_type = ForeignKey + constraint_table = await get_foreign_key_reference( + table_class=table_class, + constraint_name=fk_constraint_table.name, + constraint_schema=fk_constraint_table.schema, + ) + if constraint_table.name: + referenced_table: Union[str, Optional[type[Table]]] + + if constraint_table.name == tablename: + referenced_output_schema = output_schema + referenced_table = "self" + else: + referenced_output_schema = ( + await create_table_class_from_db( + table_class=table_class, + tablename=constraint_table.name, + schema_name=constraint_table.schema, + engine_type=engine_type, + ) + ) + referenced_table = ( + referenced_output_schema.get_table_with_name( + tablename=constraint_table.name + ) + ) + kwargs["references"] = ( + referenced_table + if referenced_table 
is not None + else ForeignKeyPlaceholder + ) + + trigger = triggers.get_column_ref_trigger( + column_name, constraint_table.name + ) + if trigger: + kwargs["on_update"] = OnUpdate(trigger.on_update) + kwargs["on_delete"] = OnDelete(trigger.on_delete) + else: + output_schema.trigger_warnings.append( + f"{tablename}.{column_name}" + ) + + output_schema = sum( # type: ignore + [output_schema, referenced_output_schema] # type: ignore + ) # type: ignore + else: + kwargs["references"] = ForeignKeyPlaceholder + + output_schema.imports.append( + "from piccolo.columns.column_types import " + + column_type.__name__ # type: ignore + ) + + if column_type is Varchar: + kwargs["length"] = pg_row_meta.character_maximum_length + elif isinstance(column_type, Numeric): + radix = pg_row_meta.numeric_precision_radix + if radix: + precision = int(str(pg_row_meta.numeric_precision), radix) + scale = int(str(pg_row_meta.numeric_scale), radix) + kwargs["digits"] = (precision, scale) + else: + kwargs["digits"] = None + + if column_default: + default_value = get_column_default( + column_type, column_default, engine_type + ) + if default_value: + kwargs["default"] = default_value + + column = column_type(**kwargs) # type: ignore + + serialised_params = serialise_params(column._meta.params) + for extra_import in serialised_params.extra_imports: + output_schema.imports.append(extra_import.__repr__()) + + columns[column_name] = column + + table = create_table_class( + class_name=_snake_to_camel(tablename), + class_kwargs={"tablename": tablename, "schema": schema_name}, + class_members=columns, + ) + output_schema.tables.append(table) + return output_schema + + +async def get_output_schema( + schema_name: str = "public", + include: Optional[list[str]] = None, + exclude: Optional[list[str]] = None, + engine: Optional[Engine] = None, +) -> OutputSchema: + """ + :param schema_name: + Name of the schema. + :param include: + Optional list of table names. Only creates the specified tables. + :param exclude: + Optional list of table names. excludes the specified tables. + :param engine: + The ``Engine`` instance to use for making database queries. If not + specified, then ``engine_finder`` is used to get the engine from + ``piccolo_conf.py``. + :returns: + OutputSchema + """ + if engine is None: + engine = engine_finder() + + if exclude is None: + exclude = [] + + if engine is None: + raise ValueError( + "Unable to find the engine - make sure piccolo_conf is on the " + "path." + ) + + if not isinstance(engine, PostgresEngine): + raise ValueError( + "This feature is currently only supported in Postgres." + ) + + class Schema(Table, db=engine): + """ + Just used for making raw queries on the db. 
+ """ + + pass + + if not include: + include = await get_tablenames(Schema, schema_name=schema_name) + + tablenames = [ + tablename for tablename in include if tablename not in exclude + ] + table_coroutines = ( + create_table_class_from_db( + table_class=Schema, + tablename=tablename, + schema_name=schema_name, + engine_type=engine.engine_type, + ) + for tablename in tablenames + ) + output_schemas = await asyncio.gather( + *table_coroutines, return_exceptions=True + ) + + # handle exceptions + exceptions = [] + for obj, tablename in zip(output_schemas, tablenames): + if isinstance(obj, Exception): + exceptions.append((obj, tablename)) + + if exceptions: + raise GenerateError( + [ + type(e)( + f"Exception occurred while generating" + f" `{tablename}` table: {e}" + ) + for e, tablename in exceptions + ] + ) + + # Merge all the output schemas to a single OutputSchema object + output_schema: OutputSchema = sum(output_schemas) # type: ignore + + # Sort the tables based on their ForeignKeys. + output_schema.tables = sort_table_classes( + sorted(output_schema.tables, key=lambda x: x._meta.tablename) + ) + output_schema.imports = sorted(set(output_schema.imports)) + + return output_schema + + +# This is currently a beta version, and can be improved. However, having +# something working is still useful for people migrating large schemas to +# Piccolo. +async def generate(schema_name: str = "public"): + """ + Automatically generates Piccolo Table classes by introspecting the + database. Please check the generated code in case there are errors. + + """ + output_schema = await get_output_schema(schema_name=schema_name) + + output = output_schema.imports + [ + i._table_str(excluded_params=["choices"]) for i in output_schema.tables + ] + + if output_schema.warnings: + warning_str = "\n".join(output_schema.warnings) + + output.append('"""') + output.append( + "WARNING: Unrecognised column types, added `Column` as a " + "placeholder:" + ) + output.append(warning_str) + output.append('"""') + + if output_schema.index_warnings: + warning_str = "\n".join(set(output_schema.index_warnings)) + + output.append('"""') + output.append("WARNING: Unable to parse the following indexes:") + output.append(warning_str) + output.append('"""') + + if output_schema.trigger_warnings: + warning_str = "\n".join(set(output_schema.trigger_warnings)) + + output.append('"""') + output.append( + "WARNING: Unable to find triggers for the following (used for " + "ON UPDATE, ON DELETE):" + ) + output.append(warning_str) + output.append('"""') + + nicely_formatted = black.format_str( + "\n".join(output), mode=black.FileMode(line_length=79) + ) + print(nicely_formatted) diff --git a/piccolo/apps/schema/commands/graph.py b/piccolo/apps/schema/commands/graph.py new file mode 100644 index 000000000..ffb09e761 --- /dev/null +++ b/piccolo/apps/schema/commands/graph.py @@ -0,0 +1,113 @@ +""" +Credit to the Django Extensions team for inspiring this tool. 
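+
+Example usage (a sketch; rendering the output assumes the Graphviz ``dot``
+tool is installed)::
+
+    piccolo schema graph --output=schema.dot
+    dot -Tpdf schema.dot -o schema.pdf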
+""" + +import dataclasses +import os +import sys +from typing import Optional + +import jinja2 + +from piccolo.conf.apps import Finder + +TEMPLATE_DIRECTORY = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "templates" +) + +JINJA_ENV = jinja2.Environment( + loader=jinja2.FileSystemLoader(searchpath=TEMPLATE_DIRECTORY), +) + + +@dataclasses.dataclass +class GraphColumn: + name: str + type: str + + +@dataclasses.dataclass +class GraphTable: + name: str + columns: list[GraphColumn] + + +@dataclasses.dataclass +class GraphRelation: + table_a: str + table_b: str + label: str + + +def render_template(**kwargs): + template = JINJA_ENV.get_template("graphviz.dot.jinja") + return template.render(**kwargs) + + +def graph( + apps: str = "all", direction: str = "LR", output: Optional[str] = None +): + """ + Prints out a graphviz .dot file for your schema. + + :param apps: + The name of the apps to include. If 'all' is given then every app is + included. To specify multiple app names, separate them with commas. + For example --apps="app1,app2". + :param direction: + How the tables should be orientated - by default it's "LR" which is + left to right, so the graph will be landscape. The alternative is + "TB", which is top to bottom, so the graph will be portrait. + :param output: + If specified, rather than printing out the file contents, they'll be + written to this file. For example --output=graph.dot + + """ + finder = Finder() + app_names = finder.get_sorted_app_names() + + if apps != "all": + given_app_names = [i.strip() for i in apps.split(",")] + delta = set(given_app_names) - set(app_names) + if delta: + sys.exit(f"These apps aren't recognised: {', '.join(delta)}.") + app_names = given_app_names + + tables: list[GraphTable] = [] + relations: list[GraphRelation] = [] + + for app_name in app_names: + app_config = finder.get_app_config(app_name=app_name) + for table_class in app_config.table_classes: + tables.append( + GraphTable( + name=table_class.__name__, + columns=[ + GraphColumn( + name=i._meta.name, type=i.__class__.__name__ + ) + for i in table_class._meta.columns + ], + ) + ) + for fk_column in table_class._meta.foreign_key_columns: + reference_table_class = ( + fk_column._foreign_key_meta.resolved_references + ) + relations.append( + GraphRelation( + table_a=table_class.__name__, + table_b=reference_table_class.__name__, + label=fk_column._meta.name, + ) + ) + + contents = render_template( + tables=tables, relations=relations, direction=direction + ) + + if output is None: + print(contents) + else: + with open(output, "w") as f: + f.write(contents) diff --git a/piccolo/apps/schema/commands/templates/graphviz.dot.jinja b/piccolo/apps/schema/commands/templates/graphviz.dot.jinja new file mode 100644 index 000000000..e014fa89b --- /dev/null +++ b/piccolo/apps/schema/commands/templates/graphviz.dot.jinja @@ -0,0 +1,53 @@ +digraph model_graph { + fontname = "Roboto" + fontsize = 8 + splines = true + rankdir = "{{ direction }}"; + + node [ + fontname = "Roboto" + fontsize = 8 + shape = "plaintext" + ] + + edge [ + fontname = "Roboto" + fontsize = 8 + ] + + // Tables + {% for table in tables %} + TABLE_{{ table.name }} [label=< + + + + + + {% for column in table.columns %} + + + + + {% endfor %} +
    + + {{ table.name }} + +
    + + {{ column.name }} + + + + {{ column.type }} + +
    + >] + {% endfor %} + + // Relations + {% for relation in relations %} + TABLE_{{ relation.table_a }} -> TABLE_{{ relation.table_b }} + [label="{{ relation.label }}"] [arrowhead=none, arrowtail=dot, dir=both]; + {% endfor %} +} diff --git a/piccolo/apps/schema/piccolo_app.py b/piccolo/apps/schema/piccolo_app.py new file mode 100644 index 000000000..ae5449d65 --- /dev/null +++ b/piccolo/apps/schema/piccolo_app.py @@ -0,0 +1,16 @@ +from piccolo.conf.apps import AppConfig, Command + +from .commands.generate import generate +from .commands.graph import graph + +APP_CONFIG = AppConfig( + app_name="schema", + migrations_folder_path="", + commands=[ + Command(callable=generate, aliases=["gen", "create", "new", "mirror"]), + Command( + callable=graph, + aliases=["map", "visualise", "vizualise", "viz", "vis"], + ), + ], +) diff --git a/piccolo/apps/shell/commands/run.py b/piccolo/apps/shell/commands/run.py index 225025f93..4b24f7fd6 100644 --- a/piccolo/apps/shell/commands/run.py +++ b/piccolo/apps/shell/commands/run.py @@ -1,7 +1,6 @@ import sys -import typing as t -from piccolo.conf.apps import AppConfig, AppRegistry, Finder +from piccolo.conf.apps import Finder from piccolo.table import Table try: @@ -13,9 +12,7 @@ IPYTHON = False -def start_ipython_shell( - **tables: t.Dict[str, t.Type[Table]] -): # pragma: no cover +def start_ipython_shell(**tables: type[Table]): # pragma: no cover if not IPYTHON: sys.exit( "Install iPython using `pip install ipython` to use this feature." @@ -26,24 +23,23 @@ def start_ipython_shell( if table_class_name not in existing_global_names: globals()[table_class_name] = table_class - IPython.embed(using=_asyncio_runner, colors="neutral") + IPython.embed(using=_asyncio_runner, colors="neutral") # type: ignore -def run(): +def run() -> None: """ Runs an iPython shell, and automatically imports all of the Table classes from your project. """ - app_registry: AppRegistry = Finder().get_app_registry() + app_registry = Finder().get_app_registry() tables = {} - spacer = "-------" - if app_registry.app_configs: + spacer = "-------" + print(spacer) for app_name, app_config in app_registry.app_configs.items(): - app_config: AppConfig = app_config print(f"Importing {app_name} tables:") if app_config.table_classes: for table_class in sorted( diff --git a/piccolo/apps/sql_shell/commands/run.py b/piccolo/apps/sql_shell/commands/run.py index dd3c09d17..a666321a4 100644 --- a/piccolo/apps/sql_shell/commands/run.py +++ b/piccolo/apps/sql_shell/commands/run.py @@ -1,22 +1,20 @@ import os import signal import subprocess -import typing as t +import sys +from typing import cast from piccolo.engine.finder import engine_finder from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine -if t.TYPE_CHECKING: # pragma: no cover - from piccolo.engine.base import Engine - -def run(): +def run() -> None: """ Launch the SQL shell for the configured engine. For Postgres this will be psql, and for SQLite it will be sqlite3. 
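+
+    For example (a sketch)::
+
+        piccolo sql_shell run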
""" - engine: t.Optional[Engine] = engine_finder() + engine = engine_finder() if engine is None: raise ValueError( @@ -26,27 +24,27 @@ def run(): # Heavily inspired by Django's dbshell command if isinstance(engine, PostgresEngine): - engine: PostgresEngine = engine + engine = cast(PostgresEngine, engine) args = ["psql"] - host = engine.config.get("host") - port = engine.config.get("port") - user = engine.config.get("user") - password = engine.config.get("password") - database = engine.config.get("database") + config = engine.config - if user: - args += ["-U", user] - if host: - args += ["-h", host] - if port: - args += ["-p", str(port)] - args += [database] + if dsn := config.get("dsn"): + args += [dsn] + else: + if user := config.get("user"): + args += ["-U", user] + if host := config.get("host"): + args += ["-h", host] + if port := config.get("port"): + args += ["-p", str(port)] + if database := config.get("database"): + args += [database] sigint_handler = signal.getsignal(signal.SIGINT) subprocess_env = os.environ.copy() - if password: + if password := config.get("password"): subprocess_env["PGPASSWORD"] = str(password) try: # Allow SIGINT to pass to psql to abort queries. @@ -58,8 +56,11 @@ def run(): signal.signal(signal.SIGINT, sigint_handler) elif isinstance(engine, SQLiteEngine): - engine: SQLiteEngine = engine + engine = cast(SQLiteEngine, engine) + + database = cast(str, engine.connection_kwargs.get("database")) + if not database: + sys.exit("Unable to determine which database to connect to.") + print("Enter .quit to exit") - subprocess.run( - ["sqlite3", engine.connection_kwargs.get("database")], check=True - ) + subprocess.run(["sqlite3", database], check=True) diff --git a/piccolo/apps/tester/__init__.py b/piccolo/apps/tester/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/apps/tester/commands/__init__.py b/piccolo/apps/tester/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/apps/tester/commands/run.py b/piccolo/apps/tester/commands/run.py new file mode 100644 index 000000000..8882fd645 --- /dev/null +++ b/piccolo/apps/tester/commands/run.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import os +import sys +from typing import Optional + +from piccolo.table import TABLE_REGISTRY + + +class set_env_var: + def __init__(self, var_name: str, temp_value: str): + """ + Temporarily set an environment variable. + + :param var_name: + The name of the environment variable to temporarily change. + :temp_value: + The value that the environment variable will temporarily be set to, + before being reset to it's pre-existing value. + + """ + self.var_name = var_name + self.temp_value = temp_value + + def set_var(self, value: str): + os.environ[self.var_name] = value + + def get_var(self) -> Optional[str]: + return os.environ.get(self.var_name) + + def __enter__(self): + self.existing_value = self.get_var() + self.set_var(self.temp_value) + + def __exit__(self, *args): + if self.existing_value is None: + del os.environ[self.var_name] + else: + self.set_var(self.existing_value) + + +def run_pytest(pytest_args: list[str]) -> int: # pragma: no cover + try: + import pytest + except ImportError: + sys.exit( + "Couldn't find pytest. Please use `pip install 'piccolo[pytest]' " + "to use this feature." + ) + + return pytest.main(pytest_args) + + +def refresh_db(): + for table_class in TABLE_REGISTRY: + # In case any table classes were imported before we set the + # environment variable. 
+ table_class._meta.refresh_db() + + +def run( + pytest_args: str = "", piccolo_conf: str = "piccolo_conf_test" +) -> None: + """ + Run your unit test suite using Pytest. + + While running, it sets the ``PICCOLO_TEST_RUNNER`` environment variable to + ``'True'``, in case any other code needs to be aware of this. + + :param piccolo_conf: + The piccolo_conf module to use when running your tests. This will + contain the database settings you want to use. For example + `my_folder.piccolo_conf_test`. + :param pytest_args: + Any options you want to pass to Pytest. For example + `piccolo tester run --pytest_args="-s"`. + + """ + with set_env_var(var_name="PICCOLO_CONF", temp_value=piccolo_conf): + refresh_db() + args = pytest_args.split(" ") + + with set_env_var(var_name="PICCOLO_TEST_RUNNER", temp_value="True"): + sys.exit(run_pytest(args)) diff --git a/piccolo/apps/tester/piccolo_app.py b/piccolo/apps/tester/piccolo_app.py new file mode 100644 index 000000000..1bda3fd0c --- /dev/null +++ b/piccolo/apps/tester/piccolo_app.py @@ -0,0 +1,11 @@ +from piccolo.conf.apps import AppConfig + +from .commands.run import run + +APP_CONFIG = AppConfig( + app_name="tester", + migrations_folder_path="", + table_classes=[], + migration_dependencies=[], + commands=[run], +) diff --git a/piccolo/apps/user/commands/change_password.py b/piccolo/apps/user/commands/change_password.py index b4106333f..22c65e6de 100644 --- a/piccolo/apps/user/commands/change_password.py +++ b/piccolo/apps/user/commands/change_password.py @@ -16,7 +16,7 @@ def change_password(): password = get_password() confirmed_password = get_confirmed_password() - if not password == confirmed_password: + if password != confirmed_password: sys.exit("Passwords don't match!") BaseUser.update_password_sync(user=username, password=password) diff --git a/piccolo/apps/user/commands/change_permissions.py b/piccolo/apps/user/commands/change_permissions.py index 3f2672ed9..63ae90991 100644 --- a/piccolo/apps/user/commands/change_permissions.py +++ b/piccolo/apps/user/commands/change_permissions.py @@ -1,17 +1,17 @@ -import typing as t +from typing import TYPE_CHECKING, Optional, Union from piccolo.apps.user.tables import BaseUser from piccolo.utils.warnings import Level, colored_string -if t.TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column async def change_permissions( username: str, - admin: t.Optional[bool] = None, - superuser: t.Optional[bool] = None, - active: t.Optional[bool] = None, + admin: Optional[bool] = None, + superuser: Optional[bool] = None, + active: Optional[bool] = None, ): """ Change a user's permissions. 
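+
+    For example (a sketch; ``bob`` is a hypothetical username)::
+
+        await change_permissions(username="bob", admin=True)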
@@ -34,7 +34,7 @@ async def change_permissions( ) return - params: t.Dict[t.Union[Column, str], bool] = {} + params: dict[Union[Column, str], bool] = {} if admin is not None: params[BaseUser.admin] = admin @@ -45,6 +45,12 @@ async def change_permissions( if active is not None: params[BaseUser.active] = active - await BaseUser.update(params).where(BaseUser.username == username).run() + if params: + await BaseUser.update(params).where( + BaseUser.username == username + ).run() + else: + print(colored_string("No changes detected", level=Level.medium)) + return print(f"Updated permissions for {username}") diff --git a/piccolo/apps/user/commands/create.py b/piccolo/apps/user/commands/create.py index 11f1e7bc7..decea97c1 100644 --- a/piccolo/apps/user/commands/create.py +++ b/piccolo/apps/user/commands/create.py @@ -1,5 +1,6 @@ import sys from getpass import getpass, getuser +from typing import Optional from piccolo.apps.user.tables import BaseUser @@ -55,26 +56,31 @@ def get_is_active() -> bool: return active == "y" -def create(): +def create( + username: Optional[str] = None, + email: Optional[str] = None, + password: Optional[str] = None, + is_admin: Optional[bool] = None, + is_superuser: Optional[bool] = None, + is_active: Optional[bool] = None, +): """ Create a new user. """ - username = get_username() - email = get_email() - password = get_password() - confirmed_password = get_confirmed_password() + username = get_username() if username is None else username + email = get_email() if email is None else email + if password is None: + password = get_password() + confirmed_password = get_confirmed_password() - if not password == confirmed_password: - sys.exit("Passwords don't match!") + if password != confirmed_password: + sys.exit("Passwords don't match!") - if len(password) < 4: - sys.exit("The password is too short") + is_admin = get_is_admin() if is_admin is None else is_admin + is_superuser = get_is_superuser() if is_superuser is None else is_superuser + is_active = get_is_active() if is_active is None else is_active - is_admin = get_is_admin() - is_superuser = get_is_superuser() - is_active = get_is_active() - - user = BaseUser( + user = BaseUser.create_user_sync( username=username, password=password, admin=is_admin, @@ -82,6 +88,5 @@ def create(): active=is_active, superuser=is_superuser, ) - user.save().run_sync() print(f"Created User {user.id}") diff --git a/piccolo/apps/user/commands/list.py b/piccolo/apps/user/commands/list.py new file mode 100644 index 000000000..88f8cb294 --- /dev/null +++ b/piccolo/apps/user/commands/list.py @@ -0,0 +1,71 @@ +from typing import Any + +from piccolo.apps.user.tables import BaseUser +from piccolo.columns import Column +from piccolo.utils.printing import print_dict_table + +ORDER_BY_COLUMN_NAMES = [ + i._meta.name for i in BaseUser.all_columns(exclude=[BaseUser.password]) +] + + +async def get_users( + order_by: Column, ascending: bool, limit: int, page: int +) -> list[dict[str, Any]]: + return ( + await BaseUser.select( + *BaseUser.all_columns(exclude=[BaseUser.password]) + ) + .order_by( + order_by, + ascending=ascending, + ) + .limit(limit) + .offset(limit * (page - 1)) + ) + + +async def list_users( + limit: int = 20, page: int = 1, order_by: str = "username" +): + """ + List existing users. + + :param limit: + The maximum number of users to list. + :param page: + Lets you paginate through the list of users. + :param order_by: + The column used to order the results. Prefix with '-' for descending + order. 
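+
+    For example (a sketch)::
+
+        await list_users(limit=10, order_by="-username")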
+ + """ + if page < 1: + raise ValueError("The page number must > 0.") + + if limit < 1: + raise ValueError("The limit number must be > 0.") + + ascending = True + if order_by.startswith("-"): + ascending = False + order_by = order_by[1:] + + if order_by not in ORDER_BY_COLUMN_NAMES: + raise ValueError( + "The order_by argument must be one of the following: " + + ", ".join(ORDER_BY_COLUMN_NAMES) + ) + + users = await get_users( + order_by=BaseUser._meta.get_column_by_name(order_by), + ascending=ascending, + limit=limit, + page=page, + ) + + if len(users) == 0: + print("No data") + return + + print_dict_table(users, header_separator=True) diff --git a/piccolo/apps/user/piccolo_app.py b/piccolo/apps/user/piccolo_app.py index b523b1b72..efa08d934 100644 --- a/piccolo/apps/user/piccolo_app.py +++ b/piccolo/apps/user/piccolo_app.py @@ -5,6 +5,7 @@ from .commands.change_password import change_password from .commands.change_permissions import change_permissions from .commands.create import create +from .commands.list import list_users from .tables import BaseUser CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) @@ -19,6 +20,7 @@ migration_dependencies=[], commands=[ Command(callable=create, aliases=["new"]), + Command(callable=list_users, command_name="list", aliases=["ls"]), Command(callable=change_password, aliases=["password", "pass"]), Command(callable=change_permissions, aliases=["perm", "perms"]), ], diff --git a/piccolo/apps/user/piccolo_migrations/2019-11-14T21-52-21.py b/piccolo/apps/user/piccolo_migrations/2019-11-14T21-52-21.py index 600e54946..8205aef64 100644 --- a/piccolo/apps/user/piccolo_migrations/2019-11-14T21-52-21.py +++ b/piccolo/apps/user/piccolo_migrations/2019-11-14T21-52-21.py @@ -15,8 +15,7 @@ async def forwards(): "length": 100, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": True, "index": False, }, @@ -30,8 +29,7 @@ async def forwards(): "length": 255, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -45,8 +43,7 @@ async def forwards(): "length": 255, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": True, "index": False, }, @@ -59,8 +56,7 @@ async def forwards(): params={ "default": False, "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -73,8 +69,7 @@ async def forwards(): params={ "default": False, "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, diff --git a/piccolo/apps/user/piccolo_migrations/2020-06-11T21-38-55.py b/piccolo/apps/user/piccolo_migrations/2020-06-11T21-38-55.py index 73701adb0..b5dc10908 100644 --- a/piccolo/apps/user/piccolo_migrations/2020-06-11T21-38-55.py +++ b/piccolo/apps/user/piccolo_migrations/2020-06-11T21-38-55.py @@ -15,8 +15,7 @@ async def forwards(): "length": 255, "default": "", "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -31,8 +30,7 @@ async def forwards(): "length": 255, "default": "", "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, diff --git a/piccolo/apps/user/piccolo_migrations/2021-04-30T16-14-15.py b/piccolo/apps/user/piccolo_migrations/2021-04-30T16-14-15.py index 7737e6767..ac1a6ecd1 100644 --- a/piccolo/apps/user/piccolo_migrations/2021-04-30T16-14-15.py +++ 
b/piccolo/apps/user/piccolo_migrations/2021-04-30T16-14-15.py @@ -18,8 +18,7 @@ async def forwards(): params={ "default": False, "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, "index_method": IndexMethod.btree, @@ -35,8 +34,7 @@ async def forwards(): params={ "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, "index_method": IndexMethod.btree, diff --git a/piccolo/apps/user/tables.py b/piccolo/apps/user/tables.py index 01da9c124..7990b9e07 100644 --- a/piccolo/apps/user/tables.py +++ b/piccolo/apps/user/tables.py @@ -1,23 +1,30 @@ """ A User model, used for authentication. """ + from __future__ import annotations +import datetime import hashlib +import logging import secrets -import typing as t +from typing import Any, Optional, Union from piccolo.columns import Boolean, Secret, Timestamp, Varchar +from piccolo.columns.column_types import Serial from piccolo.columns.readable import Readable from piccolo.table import Table from piccolo.utils.sync import run_sync +logger = logging.getLogger(__name__) + class BaseUser(Table, tablename="piccolo_user"): """ Provides a basic user, with authentication support. """ + id: Serial username = Varchar(length=100, unique=True) password = Secret(length=255) first_name = Varchar(null=True) @@ -41,11 +48,18 @@ class BaseUser(Table, tablename="piccolo_user"): help_text="When this user last logged in.", ) + _min_password_length = 6 + _max_password_length = 128 + # The number of hash iterations recommended by OWASP: + # https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 + _pbkdf2_iteration_count = 600_000 + def __init__(self, **kwargs): # Generating passwords upfront is expensive, so might need reworking. password = kwargs.get("password", None) if password: - kwargs["password"] = self.__class__.hash_password(password) + if not password.startswith("pbkdf2_sha256"): + kwargs["password"] = self.__class__.hash_password(password) super().__init__(**kwargs) @classmethod @@ -62,39 +76,89 @@ def get_readable(cls) -> Readable: ########################################################################### @classmethod - def update_password_sync(cls, user: t.Union[str, int], password: str): + def _validate_password(cls, password: str): + """ + Validate the raw password. Used by :meth:`update_password` and + :meth:`create_user`. + + :param password: + The raw password e.g. ``'hello123'``. + :raises ValueError: + If the password fails any of the criteria. + + """ + if not password: + raise ValueError("A password must be provided.") + + if len(password) < cls._min_password_length: + raise ValueError( + f"The password is too short. (min {cls._min_password_length})" + ) + + if len(password) > cls._max_password_length: + raise ValueError( + f"The password is too long. (max {cls._max_password_length})" + ) + + if password.startswith("pbkdf2_sha256"): + logger.warning( + "Tried to create a user with an already hashed password." + ) + raise ValueError("Do not pass a hashed password.") + + ########################################################################### + + @classmethod + def update_password_sync(cls, user: Union[str, int], password: str): + """ + A sync equivalent of :meth:`update_password`. 
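+
+        For example (a sketch; ``bob`` is a hypothetical username)::
+
+            BaseUser.update_password_sync(
+                user="bob", password="a new password"
+            )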
+ """ return run_sync(cls.update_password(user, password)) @classmethod - async def update_password(cls, user: t.Union[str, int], password: str): + async def update_password(cls, user: Union[str, int], password: str): """ - The password is the raw password string e.g. password123. + The password is the raw password string e.g. ``'password123'``. The user can be a user ID, or a username. """ if isinstance(user, str): clause = cls.username == user elif isinstance(user, int): - clause = cls.id == user # type: ignore + clause = cls.id == user else: raise ValueError( "The `user` arg must be a user id, or a username." ) + cls._validate_password(password=password) + password = cls.hash_password(password) - await cls.update().values({cls.password: password}).where(clause).run() + await cls.update({cls.password: password}).where(clause).run() ########################################################################### @classmethod def hash_password( - cls, password: str, salt: str = "", iterations: int = 10000 + cls, password: str, salt: str = "", iterations: Optional[int] = None ) -> str: """ Hashes the password, ready for storage, and for comparing during login. + + :raises ValueError: + If an excessively long password is provided. + """ - if salt == "": + if len(password) > cls._max_password_length: + logger.warning("Excessively long password provided.") + raise ValueError("The password is too long.") + + if not salt: salt = cls.get_salt() + + if iterations is None: + iterations = cls._pbkdf2_iteration_count + hashed = hashlib.pbkdf2_hmac( "sha256", bytes(password, encoding="utf-8"), @@ -103,56 +167,120 @@ def hash_password( ).hex() return f"pbkdf2_sha256${iterations}${salt}${hashed}" - def __setattr__(self, name: str, value: t.Any): + def __setattr__(self, name: str, value: Any): """ Make sure that if the password is set, it's stored in a hashed form. """ - if name == "password": - if not value.startswith("pbkdf2_sha256"): - value = self.__class__.hash_password(value) + if name == "password" and not value.startswith("pbkdf2_sha256"): + value = self.__class__.hash_password(value) super().__setattr__(name, value) @classmethod - def split_stored_password(cls, password: str) -> t.List[str]: + def split_stored_password(cls, password: str) -> list[str]: elements = password.split("$") if len(elements) != 4: raise ValueError("Unable to split hashed password") return elements + ########################################################################### + @classmethod - def login_sync(cls, username: str, password: str) -> t.Optional[int]: + def login_sync(cls, username: str, password: str) -> Optional[int]: """ - Returns the user_id if a match is found. + A sync equivalent of :meth:`login`. """ return run_sync(cls.login(username, password)) @classmethod - async def login(cls, username: str, password: str) -> t.Optional[int]: + async def login(cls, username: str, password: str) -> Optional[int]: """ - Returns the user_id if a match is found. + Make sure the user exists and the password is valid. If so, the + ``last_login`` value is updated in the database. + + :returns: + The id of the user if a match is found, otherwise ``None``. 
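+
+        For example (a sketch)::
+
+            user_id = await BaseUser.login(
+                username="bob", password="abc123"
+            )
+            if user_id is None:
+                ...  # authentication failed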
+ """ - query = ( - cls.select() - .columns(cls._meta.primary_key, cls.password) - .where((cls.username == username)) + if (max_username_length := cls.username.length) and len( + username + ) > max_username_length: + logger.warning("Excessively long username provided.") + return None + + if len(password) > cls._max_password_length: + logger.warning("Excessively long password provided.") + return None + + response = ( + await cls.select(cls._meta.primary_key, cls.password) + .where(cls.username == username) .first() + .run() ) - response = await query.run() if not response: - # No match found + # No match found. We still call hash_password + # here to mitigate the ability to enumerate + # users via response timings + cls.hash_password(password) return None stored_password = response["password"] - algorithm, iterations, salt, hashed = cls.split_stored_password( + algorithm, iterations_, salt, hashed = cls.split_stored_password( stored_password ) + iterations = int(iterations_) + + if cls.hash_password(password, salt, iterations) == stored_password: + # If the password was hashed in an earlier Piccolo version, update + # it so it's hashed with the currently recommended number of + # iterations: + if iterations != cls._pbkdf2_iteration_count: + await cls.update_password(username, password) - if ( - cls.hash_password(password, salt, int(iterations)) - == stored_password - ): + await cls.update({cls.last_login: datetime.datetime.now()}).where( + cls.username == username + ) return response["id"] else: return None + + ########################################################################### + + @classmethod + def create_user_sync( + cls, username: str, password: str, **extra_params + ) -> BaseUser: + """ + A sync equivalent of :meth:`create_user`. + """ + return run_sync( + cls.create_user( + username=username, password=password, **extra_params + ) + ) + + @classmethod + async def create_user( + cls, username: str, password: str, **extra_params + ) -> BaseUser: + """ + Creates a new user, and saves it in the database. It is recommended to + use this rather than instantiating and saving ``BaseUser`` directly, as + we add extra validation. + + :raises ValueError: + If the username or password is invalid. + :returns: + The created ``BaseUser`` instance. 
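+
+        For example (a sketch, with made-up credentials):
+
+        .. code-block:: python
+
+            user = await BaseUser.create_user(
+                username='bob', password='abc123'
+            )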
+ + """ + if not username: + raise ValueError("A username must be provided.") + + cls._validate_password(password=password) + + user = cls(username=username, password=password, **extra_params) + await user.save() + return user diff --git a/piccolo/columns/__init__.py b/piccolo/columns/__init__.py index eb6fc9a5f..12a258960 100644 --- a/piccolo/columns/__init__.py +++ b/piccolo/columns/__init__.py @@ -5,10 +5,13 @@ UUID, Array, BigInt, + BigSerial, Boolean, Bytea, Date, Decimal, + DoublePrecision, + Email, Float, ForeignKey, Integer, @@ -20,9 +23,51 @@ Serial, SmallInt, Text, + Time, Timestamp, Timestamptz, Varchar, ) from .combination import And, Or, Where +from .m2m import M2M from .reference import LazyTableReference + +__all__ = [ + "Column", + "ForeignKeyMeta", + "OnDelete", + "OnUpdate", + "Selectable", + "JSON", + "JSONB", + "UUID", + "Array", + "BigInt", + "BigSerial", + "Boolean", + "Bytea", + "Date", + "Decimal", + "DoublePrecision", + "Email", + "Float", + "ForeignKey", + "Integer", + "Interval", + "Numeric", + "PrimaryKey", + "Real", + "Secret", + "Serial", + "SmallInt", + "Text", + "Time", + "Timestamp", + "Timestamptz", + "Varchar", + "And", + "Or", + "Where", + "M2M", + "LazyTableReference", +] diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 8fff7e600..885768bf2 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -4,11 +4,20 @@ import datetime import decimal import inspect -import typing as t import uuid -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass, field +from collections.abc import Iterable +from dataclasses import dataclass, field, fields from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Optional, + TypedDict, + TypeVar, + Union, + cast, +) from piccolo.columns.choices import Choice from piccolo.columns.combination import Where @@ -32,15 +41,21 @@ NotLike, ) from piccolo.columns.reference import LazyTableReference -from piccolo.querystring import QueryString +from piccolo.querystring import QueryString, Selectable from piccolo.utils.warnings import colored_warning -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns.column_types import ForeignKey + from piccolo.query.methods.select import Select from piccolo.table import Table class OnDelete(str, Enum): + """ + Used by :class:`ForeignKey ` to + specify the behaviour when a related row is deleted. + """ + cascade = "CASCADE" restrict = "RESTRICT" no_action = "NO ACTION" @@ -55,6 +70,11 @@ def __repr__(self): class OnUpdate(str, Enum): + """ + Used by :class:`ForeignKey ` to + specify the behaviour when a related row is updated. 
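+
+    For example (a sketch, assuming ``Manager`` and ``Table`` are defined):
+
+    .. code-block:: python
+
+        class Band(Table):
+            manager = ForeignKey(Manager, on_update=OnUpdate.cascade)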
+    """
+
     cascade = "CASCADE"
     restrict = "RESTRICT"
     no_action = "NO ACTION"
@@ -68,20 +88,24 @@ def __repr__(self):
         return self.__str__()


+ReferencedTable = TypeVar("ReferencedTable", bound="Table")
+
+
 @dataclass
-class ForeignKeyMeta:
-    references: t.Union[t.Type[Table], LazyTableReference]
+class ForeignKeyMeta(Generic[ReferencedTable]):
+    references: Union[type[ReferencedTable], LazyTableReference]
     on_delete: OnDelete
     on_update: OnUpdate
-    proxy_columns: t.List[Column] = field(default_factory=list)
+    target_column: Union[Column, str, None]
+    proxy_columns: list[Column] = field(default_factory=list)

     @property
-    def resolved_references(self) -> t.Type[Table]:
+    def resolved_references(self) -> type[Table]:
         """
-        Evaluates the ``references`` attribute if it's a LazyTableReference,
+        Evaluates the ``references`` attribute if it's a ``LazyTableReference``,
         raising a ``ValueError`` if it fails, otherwise returns a ``Table``
         subclass.
-        """
+        """  # noqa: E501
         from piccolo.table import Table

         if isinstance(self.references, LazyTableReference):
@@ -92,19 +116,34 @@ def resolved_references(self) -> t.Type[Table]:
             return self.references
         else:
             raise ValueError(
-                "The references attribute is neither a Table sublclass or a "
+                "The references attribute is neither a Table subclass nor a "
                 "LazyTableReference instance."
             )

-    def copy(self) -> ForeignKeyMeta:
+    @property
+    def resolved_target_column(self) -> Column:
+        if self.target_column is None:
+            return self.resolved_references._meta.primary_key
+        elif isinstance(self.target_column, Column):
+            return self.resolved_references._meta.get_column_by_name(
+                self.target_column._meta.name
+            )
+        elif isinstance(self.target_column, str):
+            return self.resolved_references._meta.get_column_by_name(
+                self.target_column
+            )
+        else:
+            raise ValueError("Unable to resolve target_column.")
+
+    def copy(self) -> ForeignKeyMeta[ReferencedTable]:
         kwargs = self.__dict__.copy()
         kwargs.update(proxy_columns=self.proxy_columns.copy())
         return self.__class__(**kwargs)

-    def __copy__(self) -> ForeignKeyMeta:
+    def __copy__(self) -> ForeignKeyMeta[ReferencedTable]:
         return self.copy()

-    def __deepcopy__(self, memo) -> ForeignKeyMeta:
+    def __deepcopy__(self, memo) -> ForeignKeyMeta[ReferencedTable]:
         """
         We override deepcopy, as it's too slow if it has to recreate
         everything.
@@ -126,19 +165,32 @@ class ColumnMeta:
     index: bool = False
     index_method: IndexMethod = IndexMethod.btree
     required: bool = False
-    help_text: t.Optional[str] = None
-    choices: t.Optional[t.Type[Enum]] = None
+    help_text: Optional[str] = None
+    choices: Optional[type[Enum]] = None
+    secret: bool = False
+    auto_update: Any = ...

     # Used for representing the table in migrations and the playground.
-    params: t.Dict[str, t.Any] = field(default_factory=dict)
+    params: dict[str, Any] = field(default_factory=dict)

-    # Set by the Table Metaclass:
-    _name: t.Optional[str] = None
-    _table: t.Optional[t.Type[Table]] = None
+    ###########################################################################

-    # Used by Foreign Keys:
-    call_chain: t.List["ForeignKey"] = field(default_factory=lambda: [])
-    table_alias: t.Optional[str] = None
+    # Lets you map a column to a database column with a different name.
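+    # For example, a Python attribute called ``class_`` can map to a legacy
+    # database column called ``class``, via ``db_column_name="class"``.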
+    _db_column_name: Optional[str] = None
+
+    @property
+    def db_column_name(self) -> str:
+        return self._db_column_name or self.name
+
+    @db_column_name.setter
+    def db_column_name(self, value: str):
+        self._db_column_name = value
+
+    ###########################################################################
+
+    # Set by the Table Metaclass:
+    _name: Optional[str] = None
+    _table: Optional[type[Table]] = None

     @property
     def name(self) -> str:
@@ -153,13 +205,24 @@ def name(self, value: str):
         self._name = value

     @property
-    def table(self) -> t.Type[Table]:
+    def table(self) -> type[Table]:
         if not self._table:
             raise ValueError(
                 "`_table` isn't defined - the Table Metaclass should set it."
             )
         return self._table

+    @table.setter
+    def table(self, value: type[Table]):
+        self._table = value
+
+    ###########################################################################
+
+    # Used by Foreign Keys:
+    call_chain: list["ForeignKey"] = field(default_factory=list)
+
+    ###########################################################################
+
     @property
     def engine_type(self) -> str:
         engine = self.table._meta.db
@@ -168,48 +231,103 @@ def engine_type(self) -> str:
         else:
             raise ValueError("The table has no engine defined.")

-    def get_choices_dict(self) -> t.Optional[t.Dict[str, t.Any]]:
+    def get_choices_dict(self) -> Optional[dict[str, Any]]:
         """
         Return the choices Enum as a dict. It maps the attribute name to a
         dict containing the display name, and value.
         """
         if self.choices is None:
             return None
-        else:
-            output = {}
-            for element in self.choices:
-                if isinstance(element.value, Choice):
-                    display_name = element.value.display_name
-                    value = element.value.value
-                else:
-                    display_name = element.name.replace("_", " ").title()
-                    value = element.value
+        output = {}
+        for element in self.choices:
+            if isinstance(element.value, Choice):
+                display_name = element.value.display_name
+                value = element.value.value
+            else:
+                display_name = element.name.replace("_", " ").title()
+                value = element.value
+
+            output[element.name] = {
+                "display_name": display_name,
+                "value": value,
+            }
+
+        return output
+
+    ###########################################################################
+
+    def get_default_alias(self):
+        column_name = self.db_column_name
+
+        if self.call_chain:
+            column_name = (
+                ".".join(
+                    cast(str, i._meta.db_column_name) for i in self.call_chain
+                )
+                + f".{column_name}"
+            )

-                output[element.name] = {
-                    "display_name": display_name,
-                    "value": value,
-                }
+        return column_name

-            return output
+    def _get_path(self, include_quotes: bool = False):
+        column_name = self.db_column_name

-    def get_full_name(self, just_alias=False) -> str:
+        if self.call_chain:
+            table_alias = self.call_chain[-1].table_alias
+            if include_quotes:
+                return f'"{table_alias}"."{column_name}"'
+            else:
+                return f"{table_alias}.{column_name}"
+        else:
+            if include_quotes:
+                return f'"{self.table._meta.tablename}"."{column_name}"'
+            else:
+                return f"{self.table._meta.tablename}.{column_name}"
+
+    def get_full_name(
+        self,
+        with_alias: bool = True,
+        include_quotes: bool = True,
+    ) -> str:
         """
         Returns the full column name, taking into account joins.
+
+        :param with_alias:
+            Examples:
+
+            .. code-block:: python
+
+                >>> Band.manager.name._meta.get_full_name(with_alias=False)
+                'band$manager.name'
+
+                >>> Band.manager.name._meta.get_full_name(with_alias=True)
+                'band$manager.name AS "manager.name"'
+
+        :param include_quotes:
+            If you're using the name in a SQL query, each component needs to be
+            surrounded by double quotes, in case the table or column name
+            clashes with a reserved SQL keyword (for example, a column called
+            ``order``).
+
+            .. code-block:: python
+
+                >>> column._meta.get_full_name(include_quotes=True)
+                '"my_table_name"."my_column_name"'
+
+                >>> column._meta.get_full_name(include_quotes=False)
+                'my_table_name.my_column_name'
+
         """
-        column_name = self.name
+        full_name = self._get_path(include_quotes=include_quotes)

-        if not self.call_chain:
-            return f"{self.table._meta.tablename}.{column_name}"
+        if with_alias:
+            alias = self.get_default_alias()
+            if include_quotes:
+                full_name += f' AS "{alias}"'
+            else:
+                full_name += f" AS {alias}"

-        column_name = (
-            "$".join([i._meta.name for i in self.call_chain])
-            + f"${column_name}"
-        )
-        alias = f"{self.call_chain[-1]._meta.table_alias}.{self.name}"
-        if just_alias:
-            return alias
-        else:
-            return f'{alias} AS "{column_name}"'
+        return full_name

     ###########################################################################

@@ -219,6 +337,16 @@ def copy(self) -> ColumnMeta:
             params=self.params.copy(),
             call_chain=self.call_chain.copy(),
         )
+
+        # Make sure we don't accidentally include any other attributes which
+        # aren't supported by the constructor.
+        field_names = [i.name for i in fields(self.__class__)]
+        kwargs = {
+            kwarg: value
+            for kwarg, value in kwargs.items()
+            if kwarg in field_names
+        }
+
         return self.__class__(**kwargs)

     def __copy__(self) -> ColumnMeta:
@@ -232,24 +360,23 @@ def __deepcopy__(self, memo) -> ColumnMeta:
         return self.copy()


-class Selectable(metaclass=ABCMeta):
-    alias: t.Optional[str]
-
-    @abstractmethod
-    def get_select_string(self, engine_type: str, just_alias=False) -> str:
-        """
-        In a query, what to output after the select statement - could be a
-        column name, a sub query, a function etc. For a column it will be the
-        column name.
-        """
-        pass
+class ColumnKwargs(TypedDict, total=False):
+    """
+    Additional arguments which can be passed to :class:`Column` from
+    subclasses.
+    """

-    def as_alias(self, alias: str) -> Selectable:
-        """
-        Allows column names to be changed in the result of a select.
-        """
-        self.alias = alias
-        return self
+    null: bool
+    primary_key: bool
+    unique: bool
+    index: bool
+    index_method: IndexMethod
+    required: bool
+    help_text: Optional[str]
+    choices: Optional[type[Enum]]
+    db_column_name: Optional[str]
+    secret: bool
+    auto_update: Any


 class Column(Selectable):
@@ -268,14 +395,15 @@ class Column(Selectable):
         The column value to use if not specified by the user.

     :param unique:
-        If set, a unique contraint will be added to the column.
+        If set, a unique constraint will be added to the column.

     :param index:
         Whether an index is created for the column, which can improve
         the speed of selects, but can slow down inserts.

     :param index_method:
-        If index is set to True, this specifies what type of index is created.
+        If index is set to ``True``, this specifies what type of index is
+        created.

     :param required:
         This isn't used by the database - it's to indicate to other tools that
@@ -284,14 +412,74 @@ class Column(Selectable):

     :param help_text:
         This provides some context about what the column is being used for. For
-        example, for a `Decimal` column called `value`, it could say
-        'The units are millions of dollars'. The database doesn't use this
+        example, for a ``Decimal`` column called ``value``, it could say
+        ``'The units are millions of dollars'``. The database doesn't use this
         value, but tools such as Piccolo Admin use it to show a tooltip in the
         GUI.

+    :param choices:
+        An optional Enum - when specified, other tools such as Piccolo Admin
+        will render the available options in the GUI.
+
+    :param db_column_name:
+        If specified, you can override the name used for the column in the
+        database. The main reason for this is when using a legacy database,
+        with a problematic column name (for example ``'class'``, which is a
+        reserved Python keyword). Here's an example:
+
+        .. code-block:: python
+
+            class MyTable(Table):
+                class_ = Varchar(db_column_name="class")
+
+            >>> await MyTable.select(MyTable.class_)
+            [{'id': 1, 'class': 'test'}]
+
+        This is an advanced feature which you should only need in niche
+        situations.
+
+    :param secret:
+        If ``secret=True`` is specified, it allows a user to automatically
+        omit any fields when doing a select query, to help prevent
+        inadvertent leakage of sensitive data.
+
+        .. code-block:: python
+
+            class Band(Table):
+                name = Varchar()
+                net_worth = Integer(secret=True)
+
+            >>> await Band.select(exclude_secrets=True)
+            [{'name': 'Pythonistas'}]
+
+    :param auto_update:
+        Allows you to specify a value to set this column to each time it is
+        updated (via ``MyTable.update``, or ``MyTable.save`` on an existing
+        row). A common use case is having a ``modified_on`` column.
+
+        .. code-block:: python
+
+            class Band(Table):
+                name = Varchar()
+                popularity = Integer()
+                # The value can be a function or static value:
+                modified_on = Timestamp(auto_update=datetime.datetime.now)
+
+            # This will automatically set the `modified_on` column to the
+            # current timestamp, without having to explicitly set it:
+            >>> await Band.update({
+            ...     Band.popularity: Band.popularity + 100
+            ... }).where(Band.name == 'Pythonistas')
+
+        Note - this feature is implemented purely within the ORM. If you want
+        similar functionality on the database level (i.e. if you plan on using
+        raw SQL to perform updates), then you may be better off creating SQL
+        triggers instead.
+
     """

-    value_type: t.Type = int
+    value_type: type = int
+    default: Any

     def __init__(
         self,
@@ -301,19 +489,21 @@ def __init__(
         index: bool = False,
         index_method: IndexMethod = IndexMethod.btree,
         required: bool = False,
-        help_text: t.Optional[str] = None,
-        choices: t.Optional[t.Type[Enum]] = None,
+        help_text: Optional[str] = None,
+        choices: Optional[type[Enum]] = None,
+        db_column_name: Optional[str] = None,
+        secret: bool = False,
+        auto_update: Any = ...,
         **kwargs,
     ) -> None:
-        # This is for backwards compatibility - originally there were two
-        # separate arguments `primary` and `key`, but they have now been merged
-        # into `primary_key`.
-        if (kwargs.get("primary") is True) and (kwargs.get("key") is True):
+        # This is for backwards compatibility - originally the `primary_key`
+        # argument was called `primary`.
+        if kwargs.get("primary") is True:
            primary_key = True

         # Used for migrations.
-        # We deliberately omit 'required', and 'help_text' as they don't effect
-        # the actual schema.
+        # We deliberately omit 'required', 'auto_update' and 'help_text' as
+        # they don't affect the actual schema.
        # 'choices' isn't used directly in the schema, but may be important
        # for data migrations.
kwargs.update( @@ -324,15 +514,11 @@ def __init__( "index": index, "index_method": index_method, "choices": choices, + "db_column_name": db_column_name, + "secret": secret, } ) - if kwargs.get("default", ...) is None and not null: - raise ValueError( - "A default value of None isn't allowed if the column is " - "not nullable." - ) - if choices is not None: self._validate_choices(choices, allowed_type=self.value_type) @@ -346,14 +532,17 @@ def __init__( required=required, help_text=help_text, choices=choices, + _db_column_name=db_column_name, + secret=secret, + auto_update=auto_update, ) - self.alias: t.Optional[str] = None + self._alias: Optional[str] = None def _validate_default( self, - default: t.Any, - allowed_types: t.Iterable[t.Union[None, t.Type[t.Any]]], + default: Any, + allowed_types: Iterable[Union[None, type[Any]]], allow_recursion: bool = True, ) -> bool: """ @@ -390,11 +579,16 @@ def _validate_default( ) def _validate_choices( - self, choices: t.Type[Enum], allowed_type: t.Type[t.Any] + self, choices: type[Enum], allowed_type: type[Any] ) -> bool: """ Make sure the choices value has values of the allowed_type. """ + if getattr(self, "_validated_choices", None): + # If it has previously been validated by a subclass, don't + # validate again. + return True + for element in choices: if isinstance(element.value, allowed_type): continue @@ -407,42 +601,79 @@ def _validate_choices( f"{element.name} doesn't have the correct type" ) + self._validated_choices = True + return True - def is_in(self, values: t.List[t.Any]) -> Where: - if len(values) == 0: - raise ValueError( - "The `values` list argument must contain at least one value." - ) + def is_in(self, values: Union[Select, QueryString, list[Any]]) -> Where: + from piccolo.query.methods.select import Select + + if isinstance(values, list): + if len(values) == 0: + raise ValueError( + "The `values` list argument must contain at least one " + "value." + ) + elif isinstance(values, Select): + if len(values.columns_delegate.selected_columns) != 1: + raise ValueError( + "A sub select must only return a single column." + ) + values = values.querystrings[0] + return Where(column=self, values=values, operator=In) - def not_in(self, values: t.List[t.Any]) -> Where: - if len(values) == 0: - raise ValueError( - "The `values` list argument must contain at least one value." - ) + def not_in(self, values: Union[Select, QueryString, list[Any]]) -> Where: + from piccolo.query.methods.select import Select + + if isinstance(values, list): + if len(values) == 0: + raise ValueError( + "The `values` list argument must contain at least one " + "value." + ) + elif isinstance(values, Select): + if len(values.columns_delegate.selected_columns) != 1: + raise ValueError( + "A sub select must only return a single column." + ) + values = values.querystrings[0] + return Where(column=self, values=values, operator=NotIn) def like(self, value: str) -> Where: - if "%" not in value: - raise ValueError("% is required for like operators") + """ + Both SQLite and Postgres support LIKE, but they mean different things. + + In Postgres, LIKE is case sensitive (i.e. 'foo' equals 'foo', but + 'foo' doesn't equal 'Foo'). + + In SQLite, LIKE is case insensitive for ASCII characters + (i.e. 'foo' equals 'Foo'). But not for non-ASCII characters. 
To learn + more, see the docs: + + https://sqlite.org/lang_expr.html#the_like_glob_regexp_and_match_operators + + """ return Where(column=self, value=value, operator=Like) def ilike(self, value: str) -> Where: - if "%" not in value: - raise ValueError("% is required for ilike operators") - if self._meta.engine_type == "postgres": - operator: t.Type[ComparisonOperator] = ILike + """ + Only Postgres supports ILIKE. It's used for case insensitive matching. + + For SQLite, it's just proxied to a LIKE query instead. + + """ + if self._meta.engine_type in ("postgres", "cockroach"): + operator: type[ComparisonOperator] = ILike else: colored_warning( - "SQLite doesn't support ILIKE currently, falling back to LIKE." + "SQLite doesn't support ILIKE, falling back to LIKE." ) operator = Like return Where(column=self, value=value, operator=operator) def not_like(self, value: str) -> Where: - if "%" not in value: - raise ValueError("% is required for like operators") return Where(column=self, value=value, operator=NotLike) def __lt__(self, value) -> Where: @@ -457,13 +688,71 @@ def __gt__(self, value) -> Where: def __ge__(self, value) -> Where: return Where(column=self, value=value, operator=GreaterEqualThan) - def __eq__(self, value) -> Where: # type: ignore + def _equals(self, column: Column, including_joins: bool = False) -> bool: + """ + We override ``__eq__``, in order to do queries such as: + + .. code-block:: python + + await Band.select().where(Band.name == 'Pythonistas') + + But this means that comparisons such as this can give unexpected + results: + + .. code-block:: python + + # We would expect the answer to be `True`, but we get `Where` + # instead: + >>> MyTable.some_column == MyTable.some_column + + + Also, column comparison is sometimes more complex than it appears. This + is why we have this custom method for comparing columns. + + Take this example: + + .. code-block:: python + + Band.manager.name == Manager.name + + They both refer to the ``name`` column on the ``Manager`` table, except + one has joins and the other doesn't. + + :param including_joins: + If ``True``, then we check if the columns are the same, as well as + their joins, i.e. ``Band.manager.name`` != ``Manager.name``. + + """ + if isinstance(column, Column): + if ( + self._meta.name == column._meta.name + and self._meta.table._meta.tablename + == column._meta.table._meta.tablename + ): + if including_joins: + if len(column._meta.call_chain) == len( + self._meta.call_chain + ): + return all( + column_a._equals(column_b, including_joins=False) + for column_a, column_b in zip( + column._meta.call_chain, + self._meta.call_chain, + ) + ) + + else: + return True + + return False + + def __eq__(self, value) -> Where: # type: ignore[override] if value is None: return Where(column=self, operator=IsNull) else: return Where(column=self, value=value, operator=Equal) - def __ne__(self, value) -> Where: # type: ignore + def __ne__(self, value) -> Where: # type: ignore[override] if value is None: return Where(column=self, operator=IsNotNull) else: @@ -474,15 +763,15 @@ def __hash__(self): def is_null(self) -> Where: """ - Can be used instead of `MyTable.column != None`, because some linters - don't like a comparison to None. + Can be used instead of ``MyTable.column == None``, because some linters + don't like a comparison to ``None``. """ return Where(column=self, operator=IsNull) def is_not_null(self) -> Where: """ - Can be used instead of `MyTable.column == None`, because some linters - don't like a comparison to None. 
+ Can be used instead of ``MyTable.column != None``, because some linters + don't like a comparison to ``None``. """ return Where(column=self, operator=IsNotNull) @@ -492,15 +781,62 @@ def as_alias(self, name: str) -> Column: For example: - >>> await Band.select(Band.name.as_alias('title')).run() - {'title': 'Pythonistas'} + .. code-block:: python + + >>> await Band.select(Band.name.as_alias('title')).run() + {'title': 'Pythonistas'} """ column = copy.deepcopy(self) - column.alias = name + column._alias = name return column - def get_default_value(self) -> t.Any: + def join_on(self, column: Column) -> ForeignKey: + """ + Joins are typically performed via foreign key columns. For example, + here we get the band's name and the manager's name:: + + class Manager(Table): + name = Varchar() + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + + >>> await Band.select(Band.name, Band.manager.name) + + The ``join_on`` method lets you join tables even when foreign keys + don't exist, by joining on a column in another table. + + For example, here we want to get the manager's email, but no foreign + key exists:: + + class Manager(Table): + name = Varchar(unique=True) + email = Varchar() + + class Band(Table): + name = Varchar() + manager_name = Varchar() + + >>> await Band.select( + ... Band.name, + ... Band.manager_name.join_on(Manager.name).email + ... ) + + """ + from piccolo.columns.column_types import ForeignKey + + virtual_foreign_key = ForeignKey( + references=column._meta.table, target_column=column + ) + virtual_foreign_key._meta._name = self._meta.name + virtual_foreign_key._meta.call_chain = [*self._meta.call_chain] + virtual_foreign_key._meta._table = self._meta.table + virtual_foreign_key.set_proxy_columns() + return virtual_foreign_key + + def get_default_value(self) -> Any: """ If the column has a default attribute, return it. If it's callable, return the response instead. @@ -509,24 +845,46 @@ def get_default_value(self) -> t.Any: if default is not ...: default = default.value if isinstance(default, Enum) else default is_callable = hasattr(default, "__call__") - value = default() if is_callable else default - return value + return default() if is_callable else default # type: ignore return None - def get_select_string(self, engine_type: str, just_alias=False) -> str: + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: """ - How to refer to this column in a SQL query. + How to refer to this column in a SQL query, taking account of any joins + and aliases. 
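+
+        A sketch of the expected result (the exact value depends on the
+        column and any joins):
+
+        .. code-block:: python
+
+            # Roughly '"band"."name" AS "name"' for a ``Band.name`` column:
+            Band.name.get_select_string(engine_type="postgres")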
""" - if self.alias is None: - return self._meta.get_full_name(just_alias=just_alias) - else: - original_name = self._meta.get_full_name(just_alias=True) - return f"{original_name} AS {self.alias}" - def get_where_string(self, engine_type: str) -> str: - return self.get_select_string(engine_type=engine_type, just_alias=True) + if with_alias: + if self._alias: + original_name = self._meta.get_full_name( + with_alias=False, + ) + return QueryString(f'{original_name} AS "{self._alias}"') + else: + return QueryString( + self._meta.get_full_name( + with_alias=True, + ) + ) + + return QueryString( + self._meta.get_full_name( + with_alias=False, + ) + ) - def get_sql_value(self, value: t.Any) -> t.Any: + def get_where_string(self, engine_type: str) -> QueryString: + return self.get_select_string( + engine_type=engine_type, with_alias=False + ) + + def get_sql_value( + self, + value: Any, + delimiter: str = "'", + ) -> str: """ When using DDL statements, we can't parameterise the values. An example is when setting the default for a column. So we have to convert from @@ -535,54 +893,89 @@ def get_sql_value(self, value: t.Any) -> t.Any: :param value: The Python value to convert to a string usable in a DDL statement - e.g. 1. + e.g. ``1``. + :param delimiter: + The string returned by this function is wrapped in delimiters, + ready to be added to a DDL statement. For example: + ``'hello world'``. :returns: - The string usable in the DDL statement e.g. '1'. + The string usable in the DDL statement e.g. ``'1'``. """ + from piccolo.engine.sqlite import ADAPTERS as sqlite_adapters + + # Common across all DB engines if isinstance(value, Default): - output = getattr(value, self._meta.engine_type) + return getattr(value, self._meta.engine_type) elif value is None: - output = "null" + return "null" elif isinstance(value, (float, decimal.Decimal)): - output = str(value) + return str(value) elif isinstance(value, str): - output = f"'{value}'" + return f"{delimiter}{value}{delimiter}" elif isinstance(value, bool): - output = str(value).lower() - elif isinstance(value, datetime.datetime): - output = f"'{value.isoformat().replace('T', ' ')}'" - elif isinstance(value, datetime.date): - output = f"'{value.isoformat()}'" - elif isinstance(value, datetime.time): - output = f"'{value.isoformat()}'" - elif isinstance(value, datetime.timedelta): - interval = IntervalCustom.from_timedelta(value) - output = getattr(interval, self._meta.engine_type) + return str(value).lower() elif isinstance(value, bytes): - output = f"'{value.hex()}'" - elif isinstance(value, uuid.UUID): - output = f"'{value}'" - elif isinstance(value, list): - # Convert to the array syntax. 
- output = ( - "'{" + ", ".join([self.get_sql_value(i) for i in value]) + "}'" - ) - else: - output = value + return f"{delimiter}{value.hex()}{delimiter}" + + # SQLite specific + if self._meta.engine_type == "sqlite": + if adapter := sqlite_adapters.get(type(value)): + sqlite_value = adapter(value) + return ( + f"{delimiter}{sqlite_value}{delimiter}" + if isinstance(sqlite_value, str) + else sqlite_value + ) - return output + # Postgres and Cockroach + if self._meta.engine_type in ["postgres", "cockroach"]: + if isinstance(value, datetime.datetime): + return f"{delimiter}{value.isoformat().replace('T', ' ')}{delimiter}" # noqa: E501 + elif isinstance(value, datetime.date): + return f"{delimiter}{value.isoformat()}{delimiter}" + elif isinstance(value, datetime.time): + return f"{delimiter}{value.isoformat()}{delimiter}" + elif isinstance(value, datetime.timedelta): + interval = IntervalCustom.from_timedelta(value) + return getattr(interval, self._meta.engine_type) + elif isinstance(value, uuid.UUID): + return f"{delimiter}{value}{delimiter}" + elif isinstance(value, list): + # Convert to the array syntax. + return ( + delimiter + + "{" + + ",".join( + self.get_sql_value( + i, + delimiter="" if isinstance(i, list) else '"', + ) + for i in value + ) + + "}" + + delimiter + ) + + return str(value) @property def column_type(self): return self.__class__.__name__.upper() @property - def querystring(self) -> QueryString: + def table_alias(self) -> str: + return "$".join( + f"{_key._meta.table._meta.tablename}${_key._meta.name}" + for _key in [*self._meta.call_chain, self] + ) + + @property + def ddl(self) -> str: """ Used when creating tables. """ - query = f'"{self._meta.name}" {self.column_type}' + query = f'"{self._meta.db_column_name}" {self.column_type}' if self._meta.primary_key: query += " PRIMARY KEY" if self._meta.unique: @@ -590,37 +983,38 @@ def querystring(self) -> QueryString: if not self._meta.null: query += " NOT NULL" - foreign_key_meta: t.Optional[ForeignKeyMeta] = getattr( - self, "_foreign_key_meta", None + foreign_key_meta = cast( + Optional[ForeignKeyMeta], + getattr(self, "_foreign_key_meta", None), ) if foreign_key_meta: references = foreign_key_meta.resolved_references - tablename = references._meta.tablename + tablename = references._meta.get_formatted_tablename() on_delete = foreign_key_meta.on_delete.value on_update = foreign_key_meta.on_update.value - primary_key_name = references._meta.primary_key._meta.name + target_column_name = ( + foreign_key_meta.resolved_target_column._meta.name + ) query += ( - f" REFERENCES {tablename} ({primary_key_name})" + f" REFERENCES {tablename} ({target_column_name})" f" ON DELETE {on_delete}" f" ON UPDATE {on_update}" ) - if not self._meta.primary_key: + # Always ran for Cockroach because unique_rowid() is directly + # defined for Cockroach Serial and BigSerial. + # Postgres and SQLite will not run this for Serial and BigSerial. + if self._meta.engine_type in ( + "cockroach" + ) or self.__class__.__name__ not in ("Serial", "BigSerial"): default = self.get_default_value() sql_value = self.get_sql_value(value=default) - # Escape the value if it contains a pair of curly braces, otherwise - # an empty value will appear in the compiled querystring. 
- sql_value = ( - sql_value.replace("{}", "{{}}") - if isinstance(sql_value, str) - else sql_value - ) query += f" DEFAULT {sql_value}" - return QueryString(query) + return query - def copy(self) -> Column: - column: Column = copy.copy(self) + def copy(self: Self) -> Self: + column = copy.copy(self) column._meta = self._meta.copy() return column @@ -632,7 +1026,7 @@ def __deepcopy__(self, memo) -> Column: return self.copy() def __str__(self): - return self.querystring.__str__() + return self.ddl.__str__() def __repr__(self): try: @@ -645,3 +1039,6 @@ def __repr__(self): f"{table_class_name}.{self._meta.name} - " f"{self.__class__.__name__}" ) + + +Self = TypeVar("Self", bound=Column) diff --git a/piccolo/columns/choices.py b/piccolo/columns/choices.py index d3facf4d2..2886ad94f 100644 --- a/piccolo/columns/choices.py +++ b/piccolo/columns/choices.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing as t from dataclasses import dataclass +from typing import Any @dataclass @@ -29,5 +29,5 @@ class Title(Enum): """ - value: t.Any + value: Any display_name: str diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index ba90aafe1..8df00d130 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1,14 +1,60 @@ +""" +Notes for devs +============== + +Descriptors +----------- + +Each column type implements the descriptor protocol (the ``__get__`` and +``__set__`` methods). + +This is to signal to MyPy that the following is allowed: + +.. code-block:: python + + class Band(Table): + name = Varchar() + + band = Band() + band.name = 'Pythonistas' # Without descriptors, this would be an error + +In the above example, descriptors allow us to tell MyPy that ``name`` is a +``Varchar`` when accessed on a class, but is a ``str`` when accessed on a class +instance. 
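+
+A rough sketch of the pattern (as implemented by each column type below):
+
+.. code-block:: python
+
+    class Varchar(Column):
+        @overload
+        def __get__(self, obj: Table, objtype=None) -> str: ...
+
+        @overload
+        def __get__(self, obj: None, objtype=None) -> Varchar: ...
+
+        def __get__(self, obj, objtype=None):
+            return obj.__dict__[self._meta.name] if obj else self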
+ +""" + from __future__ import annotations import copy import decimal import inspect -import typing as t import uuid +from collections.abc import Callable +from dataclasses import dataclass from datetime import date, datetime, time, timedelta from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Optional, + Union, + cast, + overload, +) -from piccolo.columns.base import Column, ForeignKeyMeta, OnDelete, OnUpdate +from typing_extensions import Unpack + +from piccolo.columns.base import ( + Column, + ColumnKwargs, + ForeignKeyMeta, + OnDelete, + OnUpdate, + ReferencedTable, +) from piccolo.columns.combination import Where from piccolo.columns.defaults.date import DateArg, DateCustom, DateNow from piccolo.columns.defaults.interval import IntervalArg, IntervalCustom @@ -24,15 +70,24 @@ TimestamptzNow, ) from piccolo.columns.defaults.uuid import UUID4, UUIDArg -from piccolo.columns.operators.comparison import ArrayAll, ArrayAny -from piccolo.columns.operators.string import ConcatPostgres, ConcatSQLite +from piccolo.columns.operators.comparison import ( + ArrayAll, + ArrayAny, + ArrayNotAny, +) +from piccolo.columns.operators.string import Concat from piccolo.columns.reference import LazyTableReference -from piccolo.querystring import QueryString, Unquoted +from piccolo.querystring import QueryString from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta + from piccolo.query.functions.array import ArrayItemType, ArrayType + from piccolo.query.operators.json import ( + GetChildElement, + GetElementFromPath, + ) from piccolo.table import Table @@ -41,71 +96,61 @@ class ConcatDelegate: """ - Used in update queries to concatenate two strings - for example: + Used in update queries to concatenate two strings - for example:: + + await Band.update({Band.name: Band.name + 'abc'}) - await Band.update({Band.name: Band.name + 'abc'}).run() """ def get_querystring( self, - column_name: str, - value: t.Union[str, Varchar, Text], - engine_type: str, - reverse=False, + column: Column, + value: Union[str, Column, QueryString], + reverse: bool = False, ) -> QueryString: - Concat = ConcatPostgres if engine_type == "postgres" else ConcatSQLite + """ + :param reverse: + By default the value is appended to the column's value. If + ``reverse=True`` then the value is prepended to the column's + value instead. - if isinstance(value, (Varchar, Text)): - column: Column = value + """ + if isinstance(value, Column): if len(column._meta.call_chain) > 0: raise ValueError( "Adding values across joins isn't currently supported." ) - other_column_name = column._meta.name - if reverse: - return QueryString( - Concat.template.format( - value_1=other_column_name, value_2=column_name - ) - ) - else: - return QueryString( - Concat.template.format( - value_1=column_name, value_2=other_column_name - ) - ) elif isinstance(value, str): - if reverse: - value_1 = QueryString("CAST({} AS text)", value) - return QueryString( - Concat.template.format(value_1="{}", value_2=column_name), - value_1, - ) - else: - value_2 = QueryString("CAST({} AS text)", value) - return QueryString( - Concat.template.format(value_1=column_name, value_2="{}"), - value_2, - ) - else: + value = QueryString("CAST({} AS TEXT)", value) + elif not isinstance(value, QueryString): raise ValueError( - "Only str, Varchar columns, and Text columns can be added." 
+ "Only str, Column and QueryString values can be added." ) + args = [value, column] if reverse else [column, value] + + # We use the concat operator instead of the concat function, because + # this is what we historically used, and they treat null values + # differently. + return QueryString( + Concat.template.format(value_1="{}", value_2="{}"), *args + ) + class MathDelegate: """ - Used in update queries to perform math operations on columns, for example: + Used in update queries to perform math operations on columns, for example:: + + await Band.update({Band.popularity: Band.popularity + 100}) - await Band.update({Band.popularity: Band.popularity + 100}).run() """ def get_querystring( self, column_name: str, - operator: str, - value: t.Union[int, float, Integer], - reverse=False, + operator: Literal["+", "-", "/", "*"], + value: Union[int, float, Integer], + reverse: bool = False, ) -> QueryString: if isinstance(value, Integer): column: Integer = value @@ -113,11 +158,8 @@ def get_querystring( raise ValueError( "Adding values across joins isn't currently supported." ) - column_name = column._meta.name - if reverse: - return QueryString(f"{column_name} {operator} {column_name}") - else: - return QueryString(f"{column_name} {operator} {column_name}") + other_column_name = value._meta.db_column_name + return QueryString(f"{column_name} {operator} {other_column_name}") elif isinstance(value, (int, float)): if reverse: return QueryString(f"{{}} {operator} {column_name}", value) @@ -130,6 +172,131 @@ def get_querystring( ) +class TimedeltaDelegate: + """ + Used in update queries to add a timedelta to these columns: + + * ``Timestamp`` + * ``Timestamptz`` + * ``Date`` + * ``Interval`` + + Example:: + + class Concert(Table): + starts = Timestamp() + + # Lets us increase all of the matching values by 1 day: + >>> await Concert.update({ + ... Concert.starts: Concert.starts + datetime.timedelta(days=1) + ... }) + + """ + + # Maps the attribute name in Python's timedelta to what it's called in + # Postgres. 
+ postgres_attr_map: dict[str, str] = { + "days": "DAYS", + "seconds": "SECONDS", + "microseconds": "MICROSECONDS", + } + + def get_postgres_interval_string(self, interval: timedelta) -> str: + """ + :returns: + A string like:: + + "'1 DAYS 5 SECONDS 1000 MICROSECONDS'" + + """ + output = [] + for timedelta_key, postgres_name in self.postgres_attr_map.items(): + timestamp_value = getattr(interval, timedelta_key) + if timestamp_value: + output.append(f"{timestamp_value} {postgres_name}") + + output_string = " ".join(output) + return f"'{output_string}'" + + def get_sqlite_interval_string(self, interval: timedelta) -> str: + """ + :returns: + A string like:: + + "'+1 DAYS', '+5.001 SECONDS'" + + """ + output = [] + + data = { + "DAYS": interval.days, + "SECONDS": interval.seconds + (interval.microseconds / 10**6), + } + + for key, value in data.items(): + if value: + operator = "+" if value >= 0 else "" + output.append(f"'{operator}{value} {key}'") + + output_string = ", ".join(output) + return output_string + + def get_querystring( + self, + column: Column, + operator: Literal["+", "-"], + value: timedelta, + engine_type: str, + ) -> QueryString: + column_name = column._meta.name + + if not isinstance(value, timedelta): + raise ValueError("Only timedelta values can be added.") + + if engine_type in ("postgres", "cockroach"): + value_string = self.get_postgres_interval_string(interval=value) + return QueryString( + f'"{column_name}" {operator} INTERVAL {value_string}', + ) + elif engine_type == "sqlite": + if isinstance(column, Interval): + # SQLite doesn't have a proper Interval type. Instead we store + # the number of seconds. + return QueryString( + f'CAST("{column_name}" AS REAL) {operator} {value.total_seconds()}' # noqa: E501 + ) + elif isinstance(column, (Timestamp, Timestamptz)): + if ( + round(value.microseconds / 1000) * 1000 + != value.microseconds + ): + raise ValueError( + "timedeltas with such high precision won't save " + "accurately - the max resolution is 1 millisecond." + ) + strftime_format = "%Y-%m-%d %H:%M:%f" + elif isinstance(column, Date): + strftime_format = "%Y-%m-%d" + else: + raise ValueError( + f"{column.__class__.__name__} doesn't support timedelta " + "addition currently." + ) + + if operator == "-": + value = value * -1 + + value_string = self.get_sqlite_interval_string(interval=value) + + # We use `strftime` instead of `datetime`, because `datetime` + # doesn't return microseconds. 
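+            # e.g. for a ``Timestamp`` column plus ``timedelta(days=1)``,
+            # this builds: strftime('%Y-%m-%d %H:%M:%f', "starts", '+1 DAYS')
+            # (using the ``Concert.starts`` column from the example above).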
+ return QueryString( + f"strftime('{strftime_format}', \"{column_name}\", {value_string})" # noqa: E501 + ) + else: + raise ValueError("Unrecognised engine") + + ############################################################################### @@ -146,10 +313,10 @@ class Band(Table): name = Varchar(length=100) # Create - >>> Band(name='Pythonistas').save().run_sync() + >>> await Band(name='Pythonistas').save() # Query - >>> Band.select(Band.name).run_sync() + >>> await Band.select(Band.name) {'name': 'Pythonistas'} :param length: @@ -162,68 +329,87 @@ class Band(Table): def __init__( self, - length: int = 255, - default: t.Union[str, Enum, t.Callable[[], str], None] = "", - **kwargs, + length: Optional[int] = 255, + default: Union[str, Enum, Callable[[], str], None] = "", + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (str, None)) self.length = length self.default = default - kwargs.update({"length": length, "default": default}) - super().__init__(**kwargs) + super().__init__(length=length, default=default, **kwargs) @property def column_type(self): - if self.length: - return f"VARCHAR({self.length})" - else: - return "VARCHAR" + return f"VARCHAR({self.length})" if self.length else "VARCHAR" - def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString: - engine_type = self._meta.table._meta.db.engine_type + ########################################################################### + # For update queries + + def __add__(self, value: Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.name, + column=self, value=value, - engine_type=engine_type, ) - def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: - engine_type = self._meta.table._meta.db.engine_type + def __radd__(self, value: Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.name, + column=self, value=value, - engine_type=engine_type, reverse=True, ) + ########################################################################### + # Descriptors -class Secret(Varchar): - """ - The database treats it the same as a ``Varchar``, but Piccolo may treat it - differently internally - for example, allowing a user to automatically - omit any secret fields when doing a select query, to help prevent - inadvertant leakage. A common use for a ``Secret`` field is a password. + @overload + def __get__(self, obj: Table, objtype=None) -> str: ... - Uses the ``str`` type for values. + @overload + def __get__(self, obj: None, objtype=None) -> Varchar: ... - **Example** + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self - .. code-block:: python + def __set__(self, obj, value: Union[str, None]): + obj.__dict__[self._meta.name] = value - class Door(Table): - code = Secret(length=100) - # Create - >>> Door(code='123abc').save().run_sync() +class Email(Varchar): + """ + Used for storing email addresses. It's identical to :class:`Varchar`, + except when using :func:`create_pydantic_model ` - + we add email validation to the Pydantic model. This means that :ref:`PiccoloAdmin` + also validates email addresses. + """ # noqa: E501 + + pass - # Query - >>> Door.select(Door.code).run_sync() - {'code': '123abc'} +class Secret(Varchar): + """ + This is just an alias to ``Varchar(secret=True)``. It's here for backwards + compatibility. 
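+
+    For example (a sketch):
+
+    .. code-block:: python
+
+        class Door(Table):
+            code = Secret(length=100)  # same as Varchar(secret=True)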
""" - pass + def __init__(self, *args, **kwargs): + kwargs["secret"] = True + super().__init__(*args, **kwargs) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> str: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Secret: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[str, None]): + obj.__dict__[self._meta.name] = value class Text(Column): @@ -239,10 +425,10 @@ class Band(Table): name = Text() # Create - >>> Band(name='Pythonistas').save().run_sync() + >>> await Band(name='Pythonistas').save() # Query - >>> Band.select(Band.name).run_sync() + >>> await Band.select(Band.name) {'name': 'Pythonistas'} """ @@ -252,29 +438,44 @@ class Band(Table): def __init__( self, - default: t.Union[str, Enum, None, t.Callable[[], str]] = "", - **kwargs, + default: Union[str, Enum, None, Callable[[], str]] = "", + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (str, None)) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) - def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString: - engine_type = self._meta.table._meta.db.engine_type + ########################################################################### + # For update queries + + def __add__(self, value: Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.name, value=value, engine_type=engine_type + column=self, + value=value, ) - def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: - engine_type = self._meta.table._meta.db.engine_type + def __radd__(self, value: Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.name, + column=self, value=value, - engine_type=engine_type, reverse=True, ) + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> str: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Text: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[str, None]): + obj.__dict__[self._meta.name] = value + class UUID(Column): """ @@ -291,17 +492,21 @@ class Band(Table): uuid = UUID() # Create - >>> DiscountCode(code=uuid.uuid4()).save().run_sync() + >>> await DiscountCode(code=uuid.uuid4()).save() # Query - >>> DiscountCode.select(DiscountCode.code).run_sync() + >>> await DiscountCode.select(DiscountCode.code) {'code': UUID('09c4c17d-af68-4ce7-9955-73dcd892e462')} """ value_type = uuid.UUID - def __init__(self, default: UUIDArg = UUID4(), **kwargs) -> None: + def __init__( + self, + default: UUIDArg = UUID4(), + **kwargs: Unpack[ColumnKwargs], + ) -> None: if default is UUID4: # In case the class is passed in, instead of an instance. default = UUID4() @@ -314,14 +519,28 @@ def __init__(self, default: UUIDArg = UUID4(), **kwargs) -> None: if isinstance(default, str): try: default = uuid.UUID(default) - except ValueError: + except ValueError as e: raise ValueError( "The default is a string, but not a valid uuid." 
- ) + ) from e self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> uuid.UUID: ... + + @overload + def __get__(self, obj: None, objtype=None) -> UUID: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[uuid.UUID, None]): + obj.__dict__[self._meta.name] = value class Integer(Column): @@ -336,10 +555,10 @@ class Band(Table): popularity = Integer() # Create - >>> Band(popularity=1000).save().run_sync() + >>> await Band(popularity=1000).save() # Query - >>> Band.select(Band.popularity).run_sync() + >>> await Band.select(Band.popularity) {'popularity': 1000} """ @@ -348,81 +567,96 @@ class Band(Table): def __init__( self, - default: t.Union[int, Enum, t.Callable[[], int], None] = 0, - **kwargs, + default: Union[int, Enum, Callable[[], int], None] = 0, + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (int, None)) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # For update queries - def __add__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __add__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, operator="+", value=value + column_name=self._meta.db_column_name, operator="+", value=value ) - def __radd__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __radd__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, + column_name=self._meta.db_column_name, operator="+", value=value, reverse=True, ) - def __sub__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __sub__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, operator="-", value=value + column_name=self._meta.db_column_name, operator="-", value=value ) - def __rsub__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __rsub__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, + column_name=self._meta.db_column_name, operator="-", value=value, reverse=True, ) - def __mul__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __mul__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, operator="*", value=value + column_name=self._meta.db_column_name, operator="*", value=value ) - def __rmul__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __rmul__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, + column_name=self._meta.db_column_name, operator="*", value=value, reverse=True, ) - def __truediv__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __truediv__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, operator="/", value=value + 
column_name=self._meta.db_column_name, operator="/", value=value ) - def __rtruediv__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __rtruediv__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, + column_name=self._meta.db_column_name, operator="/", value=value, reverse=True, ) - def __floordiv__(self, value: t.Union[int, float, Integer]) -> QueryString: + def __floordiv__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, operator="/", value=value + column_name=self._meta.db_column_name, operator="/", value=value ) - def __rfloordiv__( - self, value: t.Union[int, float, Integer] - ) -> QueryString: + def __rfloordiv__(self, value: Union[int, float, Integer]) -> QueryString: return self.math_delegate.get_querystring( - column_name=self._meta.name, + column_name=self._meta.db_column_name, operator="/", value=value, reverse=True, ) + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Integer: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + ############################################################################### # BigInt and SmallInt only exist on Postgres. SQLite treats them the same as @@ -443,23 +677,42 @@ class Band(Table): value = BigInt() # Create - >>> Band(popularity=1000000).save().run_sync() + >>> await Band(popularity=1000000).save() # Query - >>> Band.select(Band.popularity).run_sync() + >>> await Band.select(Band.popularity) {'popularity': 1000000} """ - @property - def column_type(self): - engine_type = self._meta.table._meta.db.engine_type + def _get_column_type(self, engine_type: str): if engine_type == "postgres": return "BIGINT" + elif engine_type == "cockroach": + return "BIGINT" elif engine_type == "sqlite": return "INTEGER" raise Exception("Unrecognized engine type") + @property + def column_type(self): + return self._get_column_type(engine_type=self._meta.engine_type) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... + + @overload + def __get__(self, obj: None, objtype=None) -> BigInt: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + class SmallInt(Integer): """ @@ -474,29 +727,46 @@ class Band(Table): value = SmallInt() # Create - >>> Band(popularity=1000).save().run_sync() + >>> await Band(popularity=1000).save() # Query - >>> Band.select(Band.popularity).run_sync() + >>> await Band.select(Band.popularity) {'popularity': 1000} """ @property def column_type(self): - engine_type = self._meta.table._meta.db.engine_type + engine_type = self._meta.engine_type if engine_type == "postgres": return "SMALLINT" + elif engine_type == "cockroach": + return "SMALLINT" elif engine_type == "sqlite": return "INTEGER" raise Exception("Unrecognized engine type") + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... 
+ + @overload + def __get__(self, obj: None, objtype=None) -> SmallInt: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + ############################################################################### -DEFAULT = Unquoted("DEFAULT") -NULL = Unquoted("null") +DEFAULT = QueryString("DEFAULT") +NULL = QueryString("null") class Serial(Column): @@ -506,36 +776,109 @@ class Serial(Column): @property def column_type(self): - engine_type = self._meta.table._meta.db.engine_type + engine_type = self._meta.engine_type if engine_type == "postgres": return "SERIAL" + elif engine_type == "cockroach": + return "INTEGER" elif engine_type == "sqlite": return "INTEGER" raise Exception("Unrecognized engine type") - def default(self): - engine_type = self._meta.table._meta.db.engine_type + def default(self) -> QueryString: + engine_type = self._meta.engine_type + if engine_type == "postgres": return DEFAULT + elif engine_type == "cockroach": + return QueryString("unique_rowid()") elif engine_type == "sqlite": return NULL raise Exception("Unrecognized engine type") + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Serial: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + + +class BigSerial(Serial): + """ + An alias to a large autoincrementing integer column in Postgres. + """ + + @property + def column_type(self): + engine_type = self._meta.engine_type + if engine_type == "postgres": + return "BIGSERIAL" + elif engine_type == "cockroach": + return "BIGINT" + elif engine_type == "sqlite": + return "INTEGER" + raise Exception("Unrecognized engine type") + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... + + @overload + def __get__(self, obj: None, objtype=None) -> BigSerial: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + class PrimaryKey(Serial): - def __init__(self, **kwargs) -> None: + def __init__( + self, + **kwargs: Unpack[ColumnKwargs], + ) -> None: # Set the index to False, as a database should automatically create # an index for a PrimaryKey column. kwargs.update({"primary_key": True, "index": False}) colored_warning( - "`PrimaryKey` is deprecated and " - "will be removed in future versions.", + "`PrimaryKey` is deprecated and will be removed in future " + "versions. Use `UUID(primary_key=True)` or " + "`Serial(primary_key=True)` instead. If no primary key column is " + "specified, Piccolo will automatically add one for you called " + "`id`.", category=DeprecationWarning, ) super().__init__(**kwargs) + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> int: ... + + @overload + def __get__(self, obj: None, objtype=None) -> PrimaryKey: ... 
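+    # For example, the replacement suggested in the deprecation warning above
+    # (illustrative sketch):
+    #
+    #     class Band(Table):
+    #         id = Serial(primary_key=True)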
+ + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[int, None]): + obj.__dict__[self._meta.name] = value + ############################################################################### @@ -554,20 +897,23 @@ class Concert(Table): starts = Timestamp() # Create - >>> Concert( - >>> starts=datetime.datetime(year=2050, month=1, day=1) - >>> ).save().run_sync() + >>> await Concert( + ... starts=datetime.datetime(year=2050, month=1, day=1) + ... ).save() # Query - >>> Concert.select(Concert.starts).run_sync() + >>> await Concert.select(Concert.starts) {'starts': datetime.datetime(2050, 1, 1, 0, 0)} """ value_type = datetime + timedelta_delegate = TimedeltaDelegate() def __init__( - self, default: TimestampArg = TimestampNow(), **kwargs + self, + default: TimestampArg = TimestampNow(), + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, TimestampArg.__args__) # type: ignore @@ -583,8 +929,44 @@ def __init__( default = TimestampNow() self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # For update queries + + def __add__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="+", + value=value, + engine_type=self._meta.engine_type, + ) + + def __radd__(self, value: timedelta) -> QueryString: + return self.__add__(value) + + def __sub__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="-", + value=value, + engine_type=self._meta.engine_type, + ) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> datetime: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Timestamp: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[datetime, None]): + obj.__dict__[self._meta.name] = value class Timestamptz(Column): @@ -603,14 +985,14 @@ class Concert(Table): starts = Timestamptz() # Create - >>> Concert( - >>> starts=datetime.datetime( - >>> year=2050, month=1, day=1, tzinfo=datetime.timezone.tz - >>> ) - >>> ).save().run_sync() + >>> await Concert( + ... starts=datetime.datetime( + ... year=2050, month=1, day=1, tzinfo=datetime.timezone.tz + ... ) + ... ).save() # Query - >>> Concert.select(Concert.starts).run_sync() + >>> await Concert.select(Concert.starts) { 'starts': datetime.datetime( 2050, 1, 1, 0, 0, tzinfo=datetime.timezone.utc @@ -621,8 +1003,16 @@ class Concert(Table): value_type = datetime + # Currently just used by ModelBuilder, to know that we want a timezone + # aware datetime. 
+ tz_aware = True + + timedelta_delegate = TimedeltaDelegate() + def __init__( - self, default: TimestamptzArg = TimestamptzNow(), **kwargs + self, + default: TimestamptzArg = TimestamptzNow(), + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default( default, TimestamptzArg.__args__ # type: ignore @@ -635,8 +1025,44 @@ def __init__( default = TimestamptzNow() self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # For update queries + + def __add__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="+", + value=value, + engine_type=self._meta.engine_type, + ) + + def __radd__(self, value: timedelta) -> QueryString: + return self.__add__(value) + + def __sub__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="-", + value=value, + engine_type=self._meta.engine_type, + ) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> datetime: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Timestamptz: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[datetime, None]): + obj.__dict__[self._meta.name] = value class Date(Column): @@ -653,19 +1079,24 @@ class Concert(Table): starts = Date() # Create - >>> Concert( - >>> starts=datetime.date(year=2020, month=1, day=1) - >>> ).save().run_sync() + >>> await Concert( + ... starts=datetime.date(year=2020, month=1, day=1) + ... ).save() # Query - >>> Concert.select(Concert.starts).run_sync() + >>> await Concert.select(Concert.starts) {'starts': datetime.date(2020, 1, 1)} """ value_type = date + timedelta_delegate = TimedeltaDelegate() - def __init__(self, default: DateArg = DateNow(), **kwargs) -> None: + def __init__( + self, + default: DateArg = DateNow(), + **kwargs: Unpack[ColumnKwargs], + ) -> None: self._validate_default(default, DateArg.__args__) # type: ignore if isinstance(default, date): @@ -675,8 +1106,44 @@ def __init__(self, default: DateArg = DateNow(), **kwargs) -> None: default = DateNow() self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # For update queries + + def __add__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="+", + value=value, + engine_type=self._meta.engine_type, + ) + + def __radd__(self, value: timedelta) -> QueryString: + return self.__add__(value) + + def __sub__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="-", + value=value, + engine_type=self._meta.engine_type, + ) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> date: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Date: ... 
+ + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[date, None]): + obj.__dict__[self._meta.name] = value class Time(Column): @@ -693,36 +1160,77 @@ class Concert(Table): starts = Time() # Create - >>> Concert( - >>> starts=datetime.time(hour=20, minute=0, second=0) - >>> ).save().run_sync() + >>> await Concert( + ... starts=datetime.time(hour=20, minute=0, second=0) + ... ).save() # Query - >>> Concert.select(Concert.starts).run_sync() + >>> await Concert.select(Concert.starts) {'starts': datetime.time(20, 0, 0)} """ value_type = time + timedelta_delegate = TimedeltaDelegate() - def __init__(self, default: TimeArg = TimeNow(), **kwargs) -> None: + def __init__( + self, + default: TimeArg = TimeNow(), + **kwargs: Unpack[ColumnKwargs], + ) -> None: self._validate_default(default, TimeArg.__args__) # type: ignore if isinstance(default, time): default = TimeCustom.from_time(default) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + ########################################################################### + # For update queries -class Interval(Column): # lgtm [py/missing-equals] - """ - Used for storing timedeltas. Uses the ``timedelta`` type for values. + def __add__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="+", + value=value, + engine_type=self._meta.engine_type, + ) - **Example** + def __radd__(self, value: timedelta) -> QueryString: + return self.__add__(value) - .. code-block:: python + def __sub__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="-", + value=value, + engine_type=self._meta.engine_type, + ) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> time: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Time: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[time, None]): + obj.__dict__[self._meta.name] = value + + +class Interval(Column): + """ + Used for storing timedeltas. Uses the ``timedelta`` type for values. + + **Example** + + .. code-block:: python from datetime import timedelta @@ -730,20 +1238,23 @@ class Concert(Table): duration = Interval() # Create - >>> Concert( - >>> duration=timedelta(hours=2) - >>> ).save().run_sync() + >>> await Concert( + ... duration=timedelta(hours=2) + ... 
).save() # Query - >>> Concert.select(Concert.duration).run_sync() + >>> await Concert.select(Concert.duration) {'duration': datetime.timedelta(seconds=7200)} """ value_type = timedelta + timedelta_delegate = TimedeltaDelegate() def __init__( - self, default: IntervalArg = IntervalCustom(), **kwargs + self, + default: IntervalArg = IntervalCustom(), + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, IntervalArg.__args__) # type: ignore @@ -751,28 +1262,65 @@ def __init__( default = IntervalCustom.from_timedelta(default) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) @property def column_type(self): - engine_type = self._meta.table._meta.db.engine_type - if engine_type == "postgres": + engine_type = self._meta.engine_type + if engine_type in ("postgres", "cockroach"): return "INTERVAL" elif engine_type == "sqlite": # We can't use 'INTERVAL' because the type affinity in SQLite would - # make it an integer - but we need a numeric field. + # make it an integer - but we need a text field. # https://sqlite.org/datatype3.html#determination_of_column_affinity return "SECONDS" raise Exception("Unrecognized engine type") + ########################################################################### + # For update queries + + def __add__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="+", + value=value, + engine_type=self._meta.engine_type, + ) + + def __radd__(self, value: timedelta) -> QueryString: + return self.__add__(value) + + def __sub__(self, value: timedelta) -> QueryString: + return self.timedelta_delegate.get_querystring( + column=self, + operator="-", + value=value, + engine_type=self._meta.engine_type, + ) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> timedelta: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Interval: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[timedelta, None]): + obj.__dict__[self._meta.name] = value + ############################################################################### class Boolean(Column): """ - Used for storing True / False values. Uses the ``bool`` type for values. + Used for storing ``True`` / ``False`` values. Uses the ``bool`` type for + values. **Example** @@ -782,10 +1330,10 @@ class Band(Table): has_drummer = Boolean() # Create - >>> Band(has_drummer=True).save().run_sync() + >>> await Band(has_drummer=True).save() # Query - >>> Band.select(Band.has_drummer).run_sync() + >>> await Band.select(Band.has_drummer) {'has_drummer': True} """ @@ -794,13 +1342,12 @@ class Band(Table): def __init__( self, - default: t.Union[bool, Enum, t.Callable[[], bool], None] = False, - **kwargs, + default: Union[bool, Enum, Callable[[], bool], None] = False, + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (bool, None)) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) def eq(self, value) -> Where: """ @@ -811,7 +1358,7 @@ def eq(self, value) -> Where: await MyTable.select().where( MyTable.some_boolean_column == True - ).run() + ) It's more Pythonic to use ``is True`` rather than ``== True``, which is why linters complain. 
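+        (For example, flake8 / pycodestyle typically reports ``== True``
+        comparisons as ``E712``.)
+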
 The workaround is to do the following instead:

@@ -820,7 +1367,7 @@ def eq(self, value) -> Where:

        await MyTable.select().where(
            MyTable.some_boolean_column.__eq__(True)
-        ).run()
+        )

        Using the ``__eq__`` magic method is a bit untidy, which is why this
        ``eq`` method exists.

@@ -829,7 +1376,7 @@ def eq(self, value) -> Where:

        await MyTable.select().where(
            MyTable.some_boolean_column.eq(True)
-        ).run()
+        )

        The ``ne`` method exists for the same reason, for ``!=``.

@@ -842,6 +1389,21 @@ def ne(self, value) -> Where:
        """
        return self.__ne__(value)

+    ###########################################################################
+    # Descriptors
+
+    @overload
+    def __get__(self, obj: Table, objtype=None) -> bool: ...
+
+    @overload
+    def __get__(self, obj: None, objtype=None) -> Boolean: ...
+
+    def __get__(self, obj, objtype=None):
+        return obj.__dict__[self._meta.name] if obj else self
+
+    def __set__(self, obj, value: Union[bool, None]):
+        obj.__dict__[self._meta.name] = value
+

###############################################################################

@@ -861,19 +1423,19 @@ class Ticket(Table):
        price = Numeric(digits=(5,2))

        # Create
-        >>> Ticket(price=Decimal('50.0')).save().run_sync()
+        >>> await Ticket(price=Decimal('50.0')).save()

        # Query
-        >>> Ticket.select(Ticket.price).run_sync()
+        >>> await Ticket.select(Ticket.price)
        {'price': Decimal('50.0')}

-    :arg digits:
+    :param digits:
        When creating the column, you specify how many digits are allowed
-        using a tuple. The first value is the `precision`, which is the
-        total number of digits allowed. The second value is the `range`,
+        using a tuple. The first value is the ``precision``, which is the
+        total number of digits allowed. The second value is the ``scale``,
        which specifies how many of those digits are after the decimal
        point. For example, to store monetary values up to £999.99, the
-        digits argument is `(5,2)`.
+        digits argument is ``(5,2)``.

    """

@@ -881,32 +1443,35 @@ class Ticket(Table):

    @property
    def column_type(self):
+        engine_type = self._meta.engine_type
+        if engine_type == "cockroach":
+            return "NUMERIC"  # All Numeric is the same for Cockroach.
        if self.digits:
            return f"NUMERIC({self.precision}, {self.scale})"
        else:
            return "NUMERIC"

    @property
-    def precision(self):
+    def precision(self) -> Optional[int]:
        """
        The total number of digits allowed.
        """
-        return self.digits[0]
+        return self.digits[0] if self.digits is not None else None

    @property
-    def scale(self):
+    def scale(self) -> Optional[int]:
        """
        The number of digits after the decimal point.
        """
-        return self.digits[1]
+        return self.digits[1] if self.digits is not None else None

    def __init__(
        self,
-        digits: t.Optional[t.Tuple[int, int]] = None,
-        default: t.Union[
-            decimal.Decimal, Enum, t.Callable[[], decimal.Decimal], None
+        digits: Optional[tuple[int, int]] = None,
+        default: Union[
+            decimal.Decimal, Enum, Callable[[], decimal.Decimal], None
        ] = decimal.Decimal(0.0),
-        **kwargs,
+        **kwargs: Unpack[ColumnKwargs],
    ) -> None:
        if isinstance(digits, tuple):
            if len(digits) != 2:
                raise ValueError(
                    "The `digits` argument should be a tuple of length 2, "
@@ -915,16 +1480,29 @@ def __init__(
                    "with the first value being the precision, and the second "
                    "value being the scale."
) - else: - if digits is not None: - raise ValueError("The digits argument should be a tuple.") + elif digits is not None: + raise ValueError("The digits argument should be a tuple.") self._validate_default(default, (decimal.Decimal, None)) self.default = default self.digits = digits - kwargs.update({"default": default, "digits": digits}) - super().__init__(**kwargs) + super().__init__(default=default, digits=digits, **kwargs) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Numeric: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[decimal.Decimal, None]): + obj.__dict__[self._meta.name] = value class Decimal(Numeric): @@ -932,7 +1510,20 @@ class Decimal(Numeric): An alias for Numeric. """ - pass + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Decimal: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[decimal.Decimal, None]): + obj.__dict__[self._meta.name] = value class Real(Column): @@ -948,10 +1539,10 @@ class Concert(Table): rating = Real() # Create - >>> Concert(rating=7.8).save().run_sync() + >>> await Concert(rating=7.8).save() # Query - >>> Concert.select(Concert.rating).run_sync() + >>> await Concert.select(Concert.rating) {'rating': 7.8} """ @@ -960,13 +1551,31 @@ class Concert(Table): def __init__( self, - default: t.Union[float, Enum, t.Callable[[], float], None] = 0.0, - **kwargs, + default: Union[float, Enum, Callable[[], float], None] = 0.0, + **kwargs: Unpack[ColumnKwargs], ) -> None: + if isinstance(default, int): + # For example, allow `0` as a valid default. + default = float(default) + self._validate_default(default, (float, None)) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> float: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Real: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[float, None]): + obj.__dict__[self._meta.name] = value class Float(Real): @@ -974,13 +1583,56 @@ class Float(Real): An alias for Real. """ - pass + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> float: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Float: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[float, None]): + obj.__dict__[self._meta.name] = value + + +class DoublePrecision(Real): + """ + The same as ``Real``, except the numbers are stored with greater precision. 
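+
+    **Example** (usage mirrors ``Real``):
+
+    .. code-block:: python
+
+        class Concert(Table):
+            rating = DoublePrecision()
+
+        # Create
+        >>> await Concert(rating=7.8).save()
+
+        # Query
+        >>> await Concert.select(Concert.rating)
+        {'rating': 7.8}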
+    """
+
+    @property
+    def column_type(self):
+        return "DOUBLE PRECISION"
+
+    ###########################################################################
+    # Descriptors
+
+    @overload
+    def __get__(self, obj: Table, objtype=None) -> float: ...
+
+    @overload
+    def __get__(self, obj: None, objtype=None) -> DoublePrecision: ...
+
+    def __get__(self, obj, objtype=None):
+        return obj.__dict__[self._meta.name] if obj else self
+
+    def __set__(self, obj, value: Union[float, None]):
+        obj.__dict__[self._meta.name] = value


###############################################################################


-class ForeignKey(Column):  # lgtm [py/missing-equals]
+@dataclass
+class ForeignKeySetupResponse:
+    is_lazy: bool
+
+
+class ForeignKey(Column, Generic[ReferencedTable]):
    """
    Used to reference another table. Uses the same type as the primary key
    column on the table it references.
@@ -993,14 +1645,14 @@ class Band(Table):
        manager = ForeignKey(references=Manager)

    # Create
-    >>> Band(manager=1).save().run_sync()
+    >>> await Band(manager=1).save()

    # Query
-    >>> Band.select(Band.manager).run_sync()
+    >>> await Band.select(Band.manager)
    {'manager': 1}

    # Query object
-    >>> band = await Band.objects().first().run()
+    >>> band = await Band.objects().first()
    >>> band.manager
    1

@@ -1010,14 +1662,14 @@ class Band(Table):

    .. code-block:: python

-        >>> await Band.select(Band.name, Band.manager.name).first().run()
+        >>> await Band.select(Band.name, Band.manager.name).first()
        {'name': 'Pythonistas', 'manager.name': 'Guido'}

    To retrieve all of the columns in the related table:

    .. code-block:: python

-        >>> await Band.select(Band.name, *Band.manager.all_columns()).first().run()
+        >>> await Band.select(Band.name, *Band.manager.all_columns()).first()
        {'name': 'Pythonistas', 'manager.id': 1, 'manager.name': 'Guido'}

    To get a referenced row as an object:

@@ -1026,21 +1678,21 @@ class Band(Table):

        manager = await Manager.objects().where(
            Manager.id == some_band.manager
-        ).run()
+        )

    Or use either of the following, which are just a proxy to the above:

    .. code-block:: python

-        manager = await band.get_related('manager').run()
-        manager = await band.get_related(Band.manager).run()
+        manager = await band.get_related('manager')
+        manager = await band.get_related(Band.manager)

    To change the manager:

    .. code-block:: python

        band.manager = some_manager_id
-        await band.save().run()
+        await band.save()

    :param references:
        The ``Table`` being referenced.

@@ -1107,11 +1759,11 @@ class Band(Table):

        Options:

-        * ``OnDelete.cascade`` (default)
-        * ``OnDelete.restrict``
-        * ``OnDelete.no_action``
-        * ``OnDelete.set_null``
-        * ``OnDelete.set_default``
+        * ``OnDelete.cascade`` (default)
+        * ``OnDelete.restrict``
+        * ``OnDelete.no_action``
+        * ``OnDelete.set_null``
+        * ``OnDelete.set_default``

        To learn more about the different options, see the
        `Postgres docs `_.

        .. code-block:: python

            from piccolo.columns import OnDelete

            class Band(Table):
                name = ForeignKey(
                    references=Manager,
                    on_delete=OnDelete.cascade
                )

@@ -1127,23 +1779,23 @@ class Band(Table):

    :param on_update:
        Determines what the database should do when a row has its primary key
-        updated. If set to ``OnDelete.cascade``, any rows referencing the
+        updated. If set to ``OnUpdate.cascade``, any rows referencing the
        updated row will have their references updated to point to the new
        primary key.
Options: - * ``OnUpdate.cascade`` (default) - * ``OnUpdate.restrict`` - * ``OnUpdate.no_action`` - * ``OnUpdate.set_null`` - * ``OnUpdate.set_default`` + * ``OnUpdate.cascade`` (default) + * ``OnUpdate.restrict`` + * ``OnUpdate.no_action`` + * ``OnUpdate.set_null`` + * ``OnUpdate.set_default`` To learn more about the different options, see the `Postgres docs `_. .. code-block:: python - from piccolo.columns import OnDelete + from piccolo.columns import OnUpdate class Band(Table): name = ForeignKey( @@ -1151,6 +1803,19 @@ class Band(Table): on_update=OnUpdate.cascade ) + :param target_column: + By default the ``ForeignKey`` references the primary key column on the + related table. You can specify an alternative column (it must have a + unique constraint on it though). For example: + + .. code-block:: python + + # Passing in a column reference: + ForeignKey(references=Manager, target_column=Manager.passport_number) + + # Or just the column name: + ForeignKey(references=Manager, target_column='passport_number') + """ # noqa: E501 _foreign_key_meta: ForeignKeyMeta @@ -1158,29 +1823,77 @@ class Band(Table): @property def column_type(self): """ - A ForeignKey column needs to have the same type as the primary key + A ``ForeignKey`` column needs to have the same type as the primary key column of the table being referenced. """ - referenced_table = self._foreign_key_meta.resolved_references - pk_column = referenced_table._meta.primary_key - if isinstance(pk_column, Serial): + target_column = self._foreign_key_meta.resolved_target_column + + if isinstance(target_column, BigSerial): + return BigInt()._get_column_type( + engine_type=self._meta.engine_type + ) + elif isinstance(target_column, Serial): return Integer().column_type else: - return pk_column.column_type + return target_column.column_type + @property + def value_type(self): + """ + The value type matches that of the primary key being referenced. + """ + target_column = self._foreign_key_meta.resolved_target_column + return target_column.value_type + + @overload + def __init__( + self, + references: type[ReferencedTable], + default: Any = None, + null: bool = True, + on_delete: OnDelete = OnDelete.cascade, + on_update: OnUpdate = OnUpdate.cascade, + target_column: Union[str, Column, None] = None, + **kwargs, + ) -> None: ... + + @overload def __init__( self, - references: t.Union[t.Type[Table], LazyTableReference, str], - default: t.Any = None, + references: LazyTableReference, + default: Any = None, null: bool = True, on_delete: OnDelete = OnDelete.cascade, on_update: OnUpdate = OnUpdate.cascade, + target_column: Union[str, Column, None] = None, + **kwargs, + ) -> None: ... + + @overload + def __init__( + self, + references: str, + default: Any = None, + null: bool = True, + on_delete: OnDelete = OnDelete.cascade, + on_update: OnUpdate = OnUpdate.cascade, + target_column: Union[str, Column, None] = None, + **kwargs, + ) -> None: ... + + def __init__( + self, + references: Union[type[ReferencedTable], LazyTableReference, str], + default: Any = None, + null: bool = True, + on_delete: OnDelete = OnDelete.cascade, + on_update: OnUpdate = OnUpdate.cascade, + target_column: Union[str, Column, None] = None, **kwargs, ) -> None: from piccolo.table import Table if inspect.isclass(references): - references = t.cast(t.Type, references) if issubclass(references, Table): # Using this to validate the default value - will raise a # ValueError if incorrect. 
@@ -1199,25 +1912,251 @@ def __init__(
                "on_delete": on_delete,
                "on_update": on_update,
                "null": null,
+                "target_column": target_column,
            }
        )
        super().__init__(**kwargs)

-        # This is here just for type inference - the actual value is set by
-        # the Table metaclass. We can't set the actual value here, as
-        # only the metaclass has access to the table.
+        # The ``TableMetaclass`` sets the actual value for
+        # ``ForeignKeyMeta.references``, if the user passed in a string.
        self._foreign_key_meta = ForeignKeyMeta(
-            Table, OnDelete.cascade, OnUpdate.cascade
+            references=Table if isinstance(references, str) else references,
+            on_delete=on_delete,
+            on_update=on_update,
+            target_column=target_column,
        )

+    def _setup(self, table_class: type[Table]) -> ForeignKeySetupResponse:
+        """
+        This is called by the ``TableMetaclass``. A ``ForeignKey`` column can
+        only be completely set up once its parent ``Table`` is known.
+
+        :param table_class:
+            The parent ``Table`` class for this column.
+
+        """
+        from piccolo.table import Table
+
+        params = self._meta.params
+        references = params["references"]
+
+        if isinstance(references, str):
+            if references == "self":
+                references = table_class
+            else:
+                if "." in references:
+                    # Don't allow relative modules - this may change in
+                    # the future.
+                    if references.startswith("."):
+                        raise ValueError("Relative imports aren't allowed")
+
+                    module_path, table_class_name = references.rsplit(
+                        ".", maxsplit=1
+                    )
+                else:
+                    table_class_name = references
+                    module_path = table_class.__module__
+
+                references = LazyTableReference(
+                    table_class_name=table_class_name,
+                    module_path=module_path,
+                )
+
+        is_lazy = isinstance(references, LazyTableReference)
+        is_table_class = inspect.isclass(references) and issubclass(
+            references, Table
+        )
+
+        if is_lazy or is_table_class:
+            self._foreign_key_meta.references = references
+        else:
+            raise ValueError(
+                "Error - ``references`` must be a ``Table`` subclass, or "
+                "a ``LazyTableReference`` instance."
+            )
+
+        if is_table_class:
+            # Record the reverse relationship on the target table.
+            cast(type[Table], references)._meta._foreign_key_references.append(
+                self
+            )
+
+            # Allow columns on the referenced table to be accessed via
+            # auto completion.
+            self.set_proxy_columns()
+
+        return ForeignKeySetupResponse(is_lazy=is_lazy)
+
    def copy(self) -> ForeignKey:
        column: ForeignKey = copy.copy(self)
        column._meta = self._meta.copy()
        column._foreign_key_meta = self._foreign_key_meta.copy()
        return column

-    def set_proxy_columns(self):
+    def all_columns(
+        self, exclude: Optional[list[Union[Column, str]]] = None
+    ) -> list[Column]:
+        """
+        Allow a user to access all of the columns on the related table. This is
+        intended for use with ``select`` queries, and saves the user from
+        typing out all of the columns by hand.
+
+        For example:
+
+        .. code-block:: python
+
+            await Band.select(Band.name, Band.manager.all_columns())
+
+            # Equivalent to:
+            await Band.select(
+                Band.name,
+                Band.manager.id,
+                Band.manager.name
+            )
+
+        To exclude certain columns:
+
+        .. code-block:: python
+
+            await Band.select(
+                Band.name,
+                Band.manager.all_columns(
+                    exclude=[Band.manager.id]
+                )
+            )
+
+        :param exclude:
+            Columns to exclude - can be the name of a column, or a column
+            instance. For example ``['id']`` or ``[Band.manager.id]``.
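+
+            .. code-block:: python
+
+                # Both forms are equivalent:
+                await Band.select(
+                    Band.manager.all_columns(exclude=['id'])
+                )
+                await Band.select(
+                    Band.manager.all_columns(exclude=[Band.manager.id])
+                )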
+ + """ + if exclude is None: + exclude = [] + _fk_meta = object.__getattribute__(self, "_foreign_key_meta") + + excluded_column_names = [ + i._meta.name if isinstance(i, Column) else i for i in exclude + ] + + return [ + getattr(self, column._meta.name) + for column in _fk_meta.resolved_references._meta.columns + if column._meta.name not in excluded_column_names + ] + + def reverse(self) -> ForeignKey: + """ + If there's a unique foreign key, this function reverses it. + + .. code-block:: python + + class Band(Table): + name = Varchar() + + class FanClub(Table): + band = ForeignKey(Band, unique=True) + address = Text() + + class Treasurer(Table): + fan_club = ForeignKey(FanClub, unique=True) + name = Varchar() + + It's helpful with ``get_related``, for example: + + .. code-block:: python + + >>> band = await Band.objects().first() + >>> await band.get_related(FanClub.band.reverse()) + + + It works multiple levels deep: + + .. code-block:: python + + >>> await band.get_related(Treasurer.fan_club._.band.reverse()) + + + """ + if not self._meta.unique or any( + not i._meta.unique for i in self._meta.call_chain + ): + raise ValueError("Only reverse unique foreign keys.") + + foreign_keys = [*self._meta.call_chain, self] + + root_foreign_key = foreign_keys[0] + target_column = ( + root_foreign_key._foreign_key_meta.resolved_target_column + ) + foreign_key = target_column.join_on(root_foreign_key) + + call_chain = [] + for fk in reversed(foreign_keys[1:]): + target_column = fk._foreign_key_meta.resolved_target_column + call_chain.append(target_column.join_on(fk)) + + foreign_key._meta.call_chain = call_chain + + return foreign_key + + def all_related( + self, exclude: Optional[list[Union[ForeignKey, str]]] = None + ) -> list[ForeignKey]: + """ + Returns each ``ForeignKey`` column on the related table. This is + intended for use with ``objects`` queries, where you want to return + all of the related tables as nested objects. + + For example: + + .. code-block:: python + + class Band(Table): + name = Varchar() + + class Concert(Table): + name = Varchar() + band_1 = ForeignKey(Band) + band_2 = ForeignKey(Band) + + class Tour(Table): + name = Varchar() + concert = ForeignKey(Concert) + + await Tour.objects(Tour.concert, Tour.concert.all_related()) + + # Equivalent to + await Tour.objects( + Tour.concert, + Tour.concert.band_1, + Tour.concert.band_2 + ) + + :param exclude: + Columns to exclude - can be the name of a column, or a + ``ForeignKey`` instance. For example ``['band_1']`` or + ``[Tour.concert.band_1]``. + + """ + if exclude is None: + exclude = [] + _fk_meta: ForeignKeyMeta = object.__getattribute__( + self, "_foreign_key_meta" + ) + related_fk_columns = ( + _fk_meta.resolved_references._meta.foreign_key_columns + ) + excluded_column_names = [ + i._meta.name if isinstance(i, ForeignKey) else i for i in exclude + ] + return [ + getattr(self, fk_column._meta.name) + for fk_column in related_fk_columns + if fk_column._meta.name not in excluded_column_names + ] + + def set_proxy_columns(self) -> None: """ In order to allow a fluent interface, where tables can be traversed using ForeignKeys (e.g. ``Band.manager.name``), we add attributes to @@ -1230,32 +2169,50 @@ def set_proxy_columns(self): setattr(self, _column._meta.name, _column) _fk_meta.proxy_columns.append(_column) - def all_columns(): - """ - Allow a user to access all of the columns on the related table. 
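+    # A sketch of the traversal that ``set_proxy_columns`` enables (using the
+    # ``Band`` / ``Manager`` example from above): ``Band.manager.name``
+    # resolves to a copy of ``Manager.name``, with the join back through
+    # ``Band.manager`` recorded on its ``call_chain``.
+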
+    @property
+    def _(self) -> type[ReferencedTable]:
+        """
+        This allows us to specify joins in a way which is friendly to static
+        type checkers like Mypy and Pyright.
+
+        Whilst this works::

-            For example:

+            # Fetch the band's name, and their manager's name via a foreign
+            # key:
+            await Band.select(Band.name, Band.manager.name)

-                Band.select(Band.name, *Band.manager.all_columns()).run_sync()

+        There currently isn't a 100% reliable way to tell static type checkers
+        that ``Band.manager.name`` refers to a ``name`` column on the
+        ``Manager`` table.

-            """
-            return [
-                getattr(self, column._meta.name)
-                for column in _fk_meta.resolved_references._meta.columns
-            ]

+        However, by using the ``_`` property, it works perfectly. Instead
+        of ``Band.manager.name`` we use ``Band.manager._.name``::

-            setattr(self, "all_columns", all_columns)

+            await Band.select(Band.name, Band.manager._.name)

-    def __getattribute__(self, name: str):

+        So when doing joins, after every foreign key we use ``._.`` instead of
+        ``.``. An easy way to remember this is ``._.`` looks a bit like a
+        connector in a diagram.
+
+        As Python's typing support increases, we'd love ``Band.manager.name``
+        to have the same static typing as ``Band.manager._.name`` (using some
+        kind of ``Proxy`` type), but for now this is the best solution, and is
+        a huge improvement in developer experience, as static type checkers
+        easily know if any of your joins contain typos.
+
+        """
+        return cast(type[ReferencedTable], self)
+
+    def __getattribute__(self, name: str) -> Union[Column, Any]:
        """
        Returns attributes unmodified unless they're Column instances, in
        which case a copy is returned with an updated call_chain (which
        records the joins required).
        """
        # If the ForeignKey is using a lazy reference, we need to set the
-        # attributes here. Attributes starting with a double underscore are
+        # attributes here. Attributes starting with an underscore are
        # unlikely to be column names.
- if not name.startswith("__"): + if not name.startswith("_") and name not in dir(self): try: _foreign_key_meta = object.__getattribute__( self, "_foreign_key_meta" @@ -1268,12 +2225,12 @@ def __getattribute__(self, name: str): ): object.__getattribute__(self, "set_proxy_columns")() - try: - value = object.__getattribute__(self, name) - except AttributeError: - raise AttributeError + value = object.__getattribute__(self, name) - foreignkey_class: t.Type[ForeignKey] = object.__getattribute__( + if name.startswith("_"): + return value + + foreignkey_class: type[ForeignKey] = object.__getattribute__( self, "__class__" ) @@ -1291,7 +2248,7 @@ def __getattribute__(self, name: str): raise Exception("Call chain too long!") foreign_key_meta: ForeignKeyMeta = object.__getattribute__( - self, "_foreign_key_meta" + new_column, "_foreign_key_meta" ) for proxy_column in foreign_key_meta.proxy_columns: @@ -1300,14 +2257,13 @@ def __getattribute__(self, name: str): except Exception: pass + foreign_key_meta.proxy_columns = [] + for ( column ) in value._foreign_key_meta.resolved_references._meta.columns: _column: Column = column.copy() - _column._meta.call_chain = [ - i for i in new_column._meta.call_chain - ] - _column._meta.call_chain.append(new_column) + _column._meta.call_chain = list(new_column._meta.call_chain) setattr(new_column, _column._meta.name, _column) foreign_key_meta.proxy_columns.append(_column) @@ -1318,16 +2274,37 @@ def __getattribute__(self, name: str): column_meta: ColumnMeta = object.__getattribute__(self, "_meta") new_column._meta.call_chain = column_meta.call_chain.copy() + new_column._meta.call_chain.append(self) return new_column else: return value + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> Any: ... + + @overload + def __get__( + self, obj: None, objtype=None + ) -> ForeignKey[ReferencedTable]: ... + + @overload + def __get__(self, obj: Any, objtype=None) -> Any: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Any): + obj.__dict__[self._meta.name] = value + ############################################################################### -class JSON(Column): # lgtm[py/missing-equals] +class JSON(Column): """ Used for storing JSON strings. The data is stored as text. This can be preferable to JSONB if you just want to store and retrieve JSON without @@ -1343,14 +2320,14 @@ class JSON(Column): # lgtm[py/missing-equals] def __init__( self, - default: t.Union[ + default: Union[ str, - t.List, - t.Dict, - t.Callable[[], t.Union[str, t.List, t.Dict]], + list, + dict, + Callable[[], Union[str, list, dict]], None, ] = "{}", - **kwargs, + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (str, list, dict, None)) @@ -1358,18 +2335,110 @@ def __init__( default = dump_json(default) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + self.json_operator: Optional[str] = None - self.json_operator: t.Optional[str] = None + @property + def column_type(self): + engine_type = self._meta.engine_type + if engine_type == "cockroach": + return "JSONB" # Cockroach is always JSONB. 
+        else:
+            return "JSON"
+
+    ###########################################################################
+
+    def arrow(self, key: Union[str, int, QueryString]) -> GetChildElement:
+        """
+        Allows a child element of the JSON structure to be returned - for
+        example::
+
+            >>> await RecordingStudio.select(
+            ...     RecordingStudio.facilities.arrow("restaurant")
+            ... )
+
+        """
+        from piccolo.query.operators.json import GetChildElement
+
+        alias = self._alias or self._meta.get_default_alias()
+        return GetChildElement(identifier=self, key=key, alias=alias)
+
+    def __getitem__(
+        self, value: Union[str, int, QueryString]
+    ) -> GetChildElement:
+        """
+        A shortcut for the ``arrow`` method, used for retrieving a child
+        element.
+
+        For example:
+
+        .. code-block:: python
+
+            >>> await RecordingStudio.select(
+            ...     RecordingStudio.facilities["restaurant"]
+            ... )
+
+        """
+        return self.arrow(key=value)
+
+    def from_path(
+        self,
+        path: list[Union[str, int]],
+    ) -> GetElementFromPath:
+        """
+        Allows an element of the JSON structure to be returned, which can be
+        arbitrarily deep. For example::
+
+            >>> await RecordingStudio.select(
+            ...     RecordingStudio.facilities.from_path([
+            ...         "technician",
+            ...         0,
+            ...         "first_name"
+            ...     ])
+            ... )
+
+        It's the same as calling ``arrow`` multiple times, but is more
+        efficient / convenient if extracting highly nested data::
+
+            >>> await RecordingStudio.select(
+            ...     RecordingStudio.facilities.arrow(
+            ...         "technician"
+            ...     ).arrow(
+            ...         0
+            ...     ).arrow(
+            ...         "first_name"
+            ...     )
+            ... )
+
+        """
+        from piccolo.query.operators.json import GetElementFromPath
+
+        alias = self._alias or self._meta.get_default_alias()
+        return GetElementFromPath(identifier=self, path=path, alias=alias)
+
+    ###########################################################################
+    # Descriptors
+
+    @overload
+    def __get__(self, obj: Table, objtype=None) -> str: ...
+
+    @overload
+    def __get__(self, obj: None, objtype=None) -> JSON: ...
+
+    def __get__(self, obj, objtype=None):
+        return obj.__dict__[self._meta.name] if obj else self
+
+    def __set__(self, obj, value: Union[str, dict]):
+        obj.__dict__[self._meta.name] = value


class JSONB(JSON):
    """
-    Used for storing JSON strings - Postgres only. The data is stored in a
-    binary format, and can be queried. Insertion can be slower (as it needs to
-    be converted to the binary format). The benefits of JSONB generally
-    outweigh the downsides.
+    Used for storing JSON strings - Postgres / CockroachDB only. The data is
+    stored in a binary format, and can be queried more efficiently. Insertion
+    can be slower (as it needs to be converted to the binary format). The
+    benefits of JSONB generally outweigh the downsides.

    :param default:
        Either a JSON string can be provided, or a Python ``dict`` or ``list``

    """

-    def arrow(self, key: str) -> JSONB:
-        """
-        Allows part of the JSON structure to be returned - for example,
-        for {"a": 1}, and a key value of "a", then 1 will be returned.
- """ - instance = t.cast(JSONB, self.copy()) - instance.json_operator = f"-> '{key}'" - return instance - - def get_select_string(self, engine_type: str, just_alias=False) -> str: - select_string = self._meta.get_full_name(just_alias=just_alias) - if self.json_operator is None: - return select_string - else: - if self.alias is None: - return f"{select_string} {self.json_operator}" - else: - return f"{select_string} {self.json_operator} AS {self.alias}" + @property + def column_type(self): + return "JSONB" # Must be defined, we override column_type() in JSON() + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> str: ... + + @overload + def __get__(self, obj: None, objtype=None) -> JSONB: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: Union[str, dict]): + obj.__dict__[self._meta.name] = value ############################################################################### @@ -1412,10 +2481,10 @@ class Token(Table): token = Bytea(default=b'token123') # Create - >>> Token(token=b'my-token').save().run_sync() + >>> await Token(token=b'my-token').save() # Query - >>> Token.select(Token.token).run_sync() + >>> await Token.select(Token.token) {'token': b'my-token'} """ @@ -1424,8 +2493,8 @@ class Token(Table): @property def column_type(self): - engine_type = self._meta.table._meta.db.engine_type - if engine_type == "postgres": + engine_type = self._meta.engine_type + if engine_type in ("postgres", "cockroach"): return "BYTEA" elif engine_type == "sqlite": return "BLOB" @@ -1433,15 +2502,15 @@ def column_type(self): def __init__( self, - default: t.Union[ + default: Union[ bytes, bytearray, Enum, - t.Callable[[], bytes], - t.Callable[[], bytearray], + Callable[[], bytes], + Callable[[], bytearray], None, ] = b"", - **kwargs, + **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (bytes, bytearray, None)) @@ -1449,8 +2518,22 @@ def __init__( default = bytes(default) self.default = default - kwargs.update({"default": default}) - super().__init__(**kwargs) + super().__init__(default=default, **kwargs) + + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> bytes: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Bytea: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: bytes): + obj.__dict__[self._meta.name] = value class Blob(Bytea): @@ -1458,12 +2541,47 @@ class Blob(Bytea): An alias for Bytea. """ - pass + ########################################################################### + # Descriptors + + @overload + def __get__(self, obj: Table, objtype=None) -> bytes: ... + + @overload + def __get__(self, obj: None, objtype=None) -> Blob: ... + + def __get__(self, obj, objtype=None): + return obj.__dict__[self._meta.name] if obj else self + + def __set__(self, obj, value: bytes): + obj.__dict__[self._meta.name] = value ############################################################################### +class ListProxy: + """ + Sphinx's autodoc fails if we have this function signature:: + + class Array(Column): + + def __init__(default=list): + ... 
+ + We can't use ``list`` as a default value without breaking autodoc (it + doesn't seem to like it when a class type is used as a default), so + instead we assign an instance of this class. It keeps both autodoc and MyPy + happy. In ``Array.__init__`` we then swap it out for ``list``. + """ + + def __call__(self): + return [] + + def __repr__(self): + return "list" + + class Array(Column): """ Used for storing lists of data. @@ -1476,10 +2594,10 @@ class Ticket(Table): seat_numbers = Array(base_column=Integer()) # Create - >>> Ticket(seat_numbers=[34, 35, 36]).save().run_sync() + >>> await Ticket(seat_numbers=[34, 35, 36]).save() # Query - >>> Ticket.select(Ticket.seat_numbers).run_sync() + >>> await Ticket.select(Ticket.seat_numbers) {'seat_numbers': [34, 35, 36]} """ @@ -1489,40 +2607,116 @@ class Ticket(Table): def __init__( self, base_column: Column, - default: t.Union[t.List, Enum, t.Callable[[], t.List], None] = list, - **kwargs, + default: Union[list, Enum, Callable[[], list], None] = ListProxy(), + **kwargs: Unpack[ColumnKwargs], ) -> None: if isinstance(base_column, ForeignKey): raise ValueError("Arrays of ForeignKeys aren't allowed.") + # This is a workaround because having `list` as a default breaks + # Sphinx's autodoc. + if isinstance(default, ListProxy): + default = list + self._validate_default(default, (list, None)) + choices = kwargs.get("choices") + if choices is not None: + self._validate_choices( + choices, allowed_type=base_column.value_type + ) + self._validated_choices = True + # Usually columns are given a name by the Table metaclass, but in this - # case we have to assign one manually. + # case we have to assign one manually to the base column. base_column._meta._name = base_column.__class__.__name__ self.base_column = base_column self.default = default - self.index: t.Optional[int] = None - kwargs.update({"base_column": base_column, "default": default}) - super().__init__(**kwargs) + self.index: Optional[int] = None + super().__init__(default=default, base_column=base_column, **kwargs) @property def column_type(self): - engine_type = self._meta.table._meta.db.engine_type - if engine_type == "postgres": + engine_type = self._meta.engine_type + if engine_type in ("postgres", "cockroach"): return f"{self.base_column.column_type}[]" elif engine_type == "sqlite": - return "ARRAY" + inner_column = self._get_inner_column() + return ( + f"ARRAY_{inner_column.column_type}" + if isinstance( + inner_column, (Date, Timestamp, Timestamptz, Time) + ) + else "ARRAY" + ) raise Exception("Unrecognized engine type") + def _setup_base_column(self, table_class: type[Table]): + """ + Called from the ``Table.__init_subclass__`` - makes sure + that the ``base_column`` has a reference to the parent table. + """ + self.base_column._meta._table = table_class + if isinstance(self.base_column, Array): + self.base_column._setup_base_column(table_class=table_class) + + def _get_dimensions(self, start: int = 0) -> int: + """ + A helper function to get the number of dimensions for the array. For + example:: + + >>> Array(Varchar())._get_dimensions() + 1 + + >>> Array(Array(Varchar()))._get_dimensions() + 2 + + :param start: + Ignore this - it's just used for calling this method recursively. + + """ + if isinstance(self.base_column, Array): + return self.base_column._get_dimensions(start=start + 1) + else: + return start + 1 + + def _get_inner_column(self) -> Column: + """ + A helper function to get the innermost ``Column`` for the array. 
For + example:: + + >>> Array(Varchar())._get_inner_column() + Varchar + + >>> Array(Array(Varchar()))._get_inner_column() + Varchar + + """ + if isinstance(self.base_column, Array): + return self.base_column._get_inner_column() + else: + return self.base_column + + def _get_inner_value_type(self) -> type: + """ + A helper function to get the innermost value type for the array. For + example:: + + >>> Array(Varchar())._get_inner_value_type() + str + + >>> Array(Array(Varchar()))._get_inner_value_type() + str + + """ + return self._get_inner_column().value_type + def __getitem__(self, value: int) -> Array: """ Allows queries which retrieve an item from the array. The index starts with 0 for the first value. If you were to write the SQL by hand, the - first index would be 1 instead: - - https://www.postgresql.org/docs/current/arrays.html + first index would be 1 instead (see `Postgres array docs `_). However, we keep the first index as 0 to fit better with Python. @@ -1530,22 +2724,22 @@ def __getitem__(self, value: int) -> Array: .. code-block:: python - >>> Ticket.select(Ticket.seat_numbers[0]).first().run_sync + >>> await Ticket.select(Ticket.seat_numbers[0]).first() {'seat_numbers': 325} - """ - engine_type = self._meta.table._meta.db.engine_type - if engine_type != "postgres": + """ # noqa: E501 + engine_type = self._meta.engine_type + if engine_type != "postgres" and engine_type != "cockroach": raise ValueError( - "Only Postgres supports array indexing currently." + "Only Postgres and Cockroach support array indexing." ) if isinstance(value, int): if value < 0: raise ValueError("Only positive integers are allowed.") - instance = t.cast(Array, self.copy()) + instance = cast(Array, self.copy()) # We deliberately add 1, as Postgres treats the first array element # as index 1. @@ -1554,46 +2748,230 @@ def __getitem__(self, value: int) -> Array: else: raise ValueError("Only integers can be used for indexing.") - def get_select_string(self, engine_type: str, just_alias=False) -> str: - select_string = self._meta.get_full_name(just_alias=just_alias) + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: + select_string = self._meta.get_full_name(with_alias=False) if isinstance(self.index, int): - return f"{select_string}[{self.index}]" - else: - return select_string + select_string += f"[{self.index}]" + + if with_alias: + alias = self._alias or self._meta.get_default_alias() + select_string += f' AS "{alias}"' - def any(self, value: t.Any) -> Where: + return QueryString(select_string) + + def any(self, value: Any) -> Where: """ Check if any of the items in the array match the given value. .. code-block:: python - >>> Ticket.select().where(Ticket.seat_numbers.any(510)).run_sync() + >>> await Ticket.select().where(Ticket.seat_numbers.any(510)) """ - engine_type = self._meta.table._meta.db.engine_type + engine_type = self._meta.engine_type - if engine_type == "postgres": + if engine_type in ("postgres", "cockroach"): return Where(column=self, value=value, operator=ArrayAny) elif engine_type == "sqlite": return self.like(f"%{value}%") else: raise ValueError("Unrecognised engine type") - def all(self, value: t.Any) -> Where: + def not_any(self, value: Any) -> Where: + """ + Check if the given value isn't in the array. + + .. 
code-block:: python + + >>> await Ticket.select().where(Ticket.seat_numbers.not_any(510)) + + """ + engine_type = self._meta.engine_type + + if engine_type in ("postgres", "cockroach"): + return Where(column=self, value=value, operator=ArrayNotAny) + elif engine_type == "sqlite": + return self.not_like(f"%{value}%") + else: + raise ValueError("Unrecognised engine type") + + def all(self, value: Any) -> Where: """ Check if all of the items in the array match the given value. .. code-block:: python - >>> Ticket.select().where(Ticket.seat_numbers.all(510)).run_sync() + >>> await Ticket.select().where(Ticket.seat_numbers.all(510)) """ - engine_type = self._meta.table._meta.db.engine_type + engine_type = self._meta.engine_type - if engine_type == "postgres": + if engine_type in ("postgres", "cockroach"): return Where(column=self, value=value, operator=ArrayAll) elif engine_type == "sqlite": raise ValueError("Unsupported by SQLite") else: raise ValueError("Unrecognised engine type") + + def cat(self, value: ArrayType) -> QueryString: + """ + A convenient way of accessing the + :class:`ArrayCat ` function. + + Used in an ``update`` query to concatenate two arrays. + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: Ticket.seat_numbers.cat([1000]) + ... }).where(Ticket.id == 1) + + You can also use the ``+`` symbol if you prefer. To concatenate to + the end: + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: Ticket.seat_numbers + [1000] + ... }).where(Ticket.id == 1) + + To concatenate to the start: + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: [1000] + Ticket.seat_numbers + ... }).where(Ticket.id == 1) + + You can concatenate multiple arrays in one go: + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: [1000] + Ticket.seat_numbers + [2000] + ... }).where(Ticket.id == 1) + + .. note:: Postgres / CockroachDB only + + """ + from piccolo.query.functions.array import ArrayCat + + # Keep this for backwards compatibility - we had this as a convenience + # for users, but it would be nice to remove it in the future. + if not isinstance(value, list): + value = [value] + + return ArrayCat(array_1=self, array_2=value) + + def remove(self, value: ArrayItemType) -> QueryString: + """ + A convenient way of accessing the + :class:`ArrayRemove ` + function. + + Used in an ``update`` query to remove an item from an array. + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: Ticket.seat_numbers.remove(1000) + ... }).where(Ticket.id == 1) + + .. note:: Postgres / CockroachDB only + + """ + from piccolo.query.functions.array import ArrayRemove + + return ArrayRemove(array=self, value=value) + + def prepend(self, value: ArrayItemType) -> QueryString: + """ + A convenient way of accessing the + :class:`ArrayPrepend ` + function. + + Used in an ``update`` query to prepend an item to an array. + + .. code-block:: python + + >>> await Ticket.update({ + ... Ticket.seat_numbers: Ticket.seat_numbers.prepend(1000) + ... }).where(Ticket.id == 1) + + .. note:: Postgres / CockroachDB only + + """ + from piccolo.query.functions.array import ArrayPrepend + + return ArrayPrepend(array=self, value=value) + + def append(self, value: ArrayItemType) -> QueryString: + """ + A convenient way of accessing the + :class:`ArrayAppend ` + function. + + Used in an ``update`` query to append an item to an array. + + .. code-block:: python + + >>> await Ticket.update({ + ... 
Ticket.seat_numbers: Ticket.seat_numbers.append(1000)
+            ... }).where(Ticket.id == 1)
+
+        .. note:: Postgres / CockroachDB only
+
+        """
+        from piccolo.query.functions.array import ArrayAppend
+
+        return ArrayAppend(array=self, value=value)
+
+    def replace(
+        self, old_value: ArrayItemType, new_value: ArrayItemType
+    ) -> QueryString:
+        """
+        A convenient way of accessing the
+        :class:`ArrayReplace <piccolo.query.functions.array.ArrayReplace>`
+        function.
+
+        Used in an ``update`` query to replace each array item
+        equal to the given value with a new value.
+
+        .. code-block:: python
+
+            >>> await Ticket.update({
+            ...     Ticket.seat_numbers: Ticket.seat_numbers.replace(1000, 500)
+            ... }).where(Ticket.id == 1)
+
+        .. note:: Postgres / CockroachDB only
+
+        """
+        from piccolo.query.functions.array import ArrayReplace
+
+        return ArrayReplace(self, old_value=old_value, new_value=new_value)
+
+    def __add__(self, value: ArrayType) -> QueryString:
+        return self.cat(value)
+
+    def __radd__(self, value: ArrayType) -> QueryString:
+        from piccolo.query.functions.array import ArrayCat
+
+        return ArrayCat(array_1=value, array_2=self)
+
+    ###########################################################################
+    # Descriptors
+
+    @overload
+    def __get__(self, obj: Table, objtype=None) -> list[Any]: ...
+
+    @overload
+    def __get__(self, obj: None, objtype=None) -> Array: ...
+
+    def __get__(self, obj, objtype=None):
+        return obj.__dict__[self._meta.name] if obj else self
+
+    def __set__(self, obj, value: list[Any]):
+        obj.__dict__[self._meta.name] = value
diff --git a/piccolo/columns/combination.py b/piccolo/columns/combination.py
index cca8296da..79e14bdb5 100644
--- a/piccolo/columns/combination.py
+++ b/piccolo/columns/combination.py
@@ -1,17 +1,24 @@
 from __future__ import annotations
 
-import typing as t
-
-from piccolo.columns.operators.comparison import ComparisonOperator
-from piccolo.custom_types import Combinable, Iterable
+from typing import TYPE_CHECKING, Any, Union
+
+from piccolo.columns.operators.comparison import (
+    ComparisonOperator,
+    Equal,
+    IsNull,
+)
+from piccolo.custom_types import Combinable, CustomIterable
 from piccolo.querystring import QueryString
 from piccolo.utils.sql_values import convert_to_sql_value
 
-if t.TYPE_CHECKING:
+if TYPE_CHECKING:
     from piccolo.columns.base import Column
 
 
 class CombinableMixin(object):
+
+    __slots__ = ()
+
     def __and__(self, value: Combinable) -> "And":
         return And(self, value)  # type: ignore
 
@@ -37,13 +44,56 @@ def querystring(self) -> QueryString:
             self.second.querystring,
         )
 
+    @property
+    def querystring_for_update_and_delete(self) -> QueryString:
+        return QueryString(
+            "({} " + self.operator + " {})",
+            self.first.querystring_for_update_and_delete,
+            self.second.querystring_for_update_and_delete,
+        )
+
     def __str__(self):
-        self.querystring.__str__()
+        return self.querystring.__str__()
 
 
 class And(Combination):
     operator = "AND"
 
+    def get_column_values(self) -> dict[Column, Any]:
+        """
+        This is used by `get_or_create` to know which values to assign if
+        the row doesn't exist in the database.
+
+        For example, if we have::
+
+            (Band.name == 'Pythonistas') & (Band.popularity == 1000)
+
+        We will return::
+
+            {Band.name: 'Pythonistas', Band.popularity: 1000}.
+ + If the operator is anything besides equals, we don't return it, for + example:: + + (Band.name == 'Pythonistas') & (Band.popularity > 1000) + + Returns:: + + {Band.name: 'Pythonistas'} + + """ + output = {} + for combinable in (self.first, self.second): + if isinstance(combinable, Where): + if combinable.operator == Equal: + output[combinable.column] = combinable.value + elif combinable.operator == IsNull: + output[combinable.column] = None + elif isinstance(combinable, And): + output.update(combinable.get_column_values()) + + return output + class Or(Combination): operator = "OR" @@ -57,22 +107,33 @@ class Undefined: class WhereRaw(CombinableMixin): - def __init__(self, sql: str, *args: t.Any) -> None: + __slots__ = ("querystring",) + + def __init__(self, sql: str, *args: Any) -> None: """ Execute raw SQL queries in your where clause. Use with caution! - await Band.where( - WhereRaw("name = 'Pythonistas'") - ) + .. code-block:: python + + await Band.where( + WhereRaw("name = 'Pythonistas'") + ) Or passing in parameters: - await Band.where( - WhereRaw("name = {}", 'Pythonistas') - ) + .. code-block:: python + + await Band.where( + WhereRaw("name = {}", 'Pythonistas') + ) + """ self.querystring = QueryString(sql, *args) + @property + def querystring_for_update_and_delete(self) -> QueryString: + return self.querystring + def __str__(self): return self.querystring.__str__() @@ -84,43 +145,43 @@ class Where(CombinableMixin): def __init__( self, column: Column, - value: t.Any = UNDEFINED, - values: t.Union[Iterable, Undefined] = UNDEFINED, - operator: t.Type[ComparisonOperator] = ComparisonOperator, + value: Any = UNDEFINED, + values: Union[CustomIterable, Undefined, QueryString] = UNDEFINED, + operator: type[ComparisonOperator] = ComparisonOperator, ) -> None: """ We use the UNDEFINED value to show the value was deliberately omitted, vs None, which is a valid value for a where clause. """ self.column = column - self.value = self.clean_value(value) - if values == UNDEFINED: + self.value = value if value == UNDEFINED else self.clean_value(value) + if (values == UNDEFINED) or isinstance(values, QueryString): self.values = values else: self.values = [self.clean_value(i) for i in values] # type: ignore self.operator = operator - def clean_value(self, value: t.Any) -> t.Any: + def clean_value(self, value: Any) -> Any: """ - If a where clause contains a Table instance, we should convert that + If a where clause contains a ``Table`` instance, we should convert that to a column reference. For example: .. code-block:: python - manager = Manager.objects.where( + manager = await Manager.objects.where( Manager.name == 'Guido' - ).first().run_sync() + ).first() # The where clause should be: - Band.select().where(Band.manager.id == guido.id).run_sync() + await Band.select().where(Band.manager.id == guido.id) # Or - Band.select().where(Band.manager == guido.id).run_sync() + await Band.select().where(Band.manager == guido.id) # If the object is passed in, i.e. `guido` instead of `guido.id`, # it should still work. - Band.select().where(Band.manager == guido).run_sync() + await Band.select().where(Band.manager == guido) Also, convert Enums to their underlying values, and serialise any JSON. 
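A minimal usage sketch of how ``And.get_column_values`` surfaces through ``get_or_create``, assuming a ``Band`` table with ``name = Varchar()`` and ``popularity = Integer()`` like the one in the docstrings above (the table itself is illustrative, not part of this diff):

    >>> # get_or_create uses And.get_column_values to decide which values
    >>> # to assign when no matching row exists.
    >>> band = await Band.objects().get_or_create(
    ...     (Band.name == "Pythonistas") & (Band.popularity == 1000)
    ... )
    >>> # If no row matched, the new row gets name='Pythonistas' and
    >>> # popularity=1000, because both clauses use ==. A clause such as
    >>> # Band.popularity > 1000 would contribute nothing, as described
    >>> # in get_column_values above.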
@@ -131,15 +192,18 @@ def clean_value(self, value: t.Any) -> t.Any: def values_querystring(self) -> QueryString: values = self.values + if isinstance(values, QueryString): + return values + if isinstance(values, Undefined): raise ValueError("values is undefined") - template = ", ".join(["{}" for _ in values]) + template = ", ".join("{}" for _ in values) return QueryString(template, *values) @property def querystring(self) -> QueryString: - args: t.List[t.Any] = [] + args: list[Any] = [] if self.value != UNDEFINED: args.append(self.value) @@ -156,5 +220,37 @@ def querystring(self) -> QueryString: return QueryString(template, *args) + @property + def querystring_for_update_and_delete(self) -> QueryString: + args: list[Any] = [] + if self.value != UNDEFINED: + args.append(self.value) + + if self.values != UNDEFINED: + args.append(self.values_querystring) + + column = self.column + + if column._meta.call_chain: + # Use a sub select to find the correct ID. + root_column = column._meta.call_chain[0] + sub_query = root_column._meta.table.select(root_column).where(self) + + column_name = column._meta.call_chain[0]._meta.name + return QueryString( + f"{column_name} IN ({{}})", + sub_query.querystrings[0], + ) + else: + template = self.operator.template.format( + name=self.column.get_where_string( + engine_type=self.column._meta.engine_type + ), + value="{}", + values="{}", + ) + + return QueryString(template, *args) + def __str__(self): return self.querystring.__str__() diff --git a/piccolo/columns/defaults/base.py b/piccolo/columns/defaults/base.py index 1346a2964..fcf46bd85 100644 --- a/piccolo/columns/defaults/base.py +++ b/piccolo/columns/defaults/base.py @@ -1,32 +1,35 @@ from __future__ import annotations -import typing as t -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod +from typing import Any from piccolo.utils.repr import repr_class_instance class Default(ABC): - @abstractproperty + @property + @abstractmethod def postgres(self) -> str: pass - @abstractproperty + @property + @abstractmethod def sqlite(self) -> str: pass @abstractmethod - def python(self): + def python(self) -> Any: pass - def get_postgres_interval_string(self, attributes: t.List[str]) -> str: + def get_postgres_interval_string(self, attributes: list[str]) -> str: """ Returns a string usable as an interval argument in Postgres e.g. "1 day 2 hour". - :arg attributes: + :param attributes: Date / time attributes to extract from the instance. e.g. ['hours', 'minutes'] + """ interval_components = [] for attr_name in attributes: @@ -36,14 +39,15 @@ def get_postgres_interval_string(self, attributes: t.List[str]) -> str: return " ".join(interval_components) - def get_sqlite_interval_string(self, attributes: t.List[str]) -> str: + def get_sqlite_interval_string(self, attributes: list[str]) -> str: """ Returns a string usable as an interval argument in SQLite e.g. "'-2 hours', '1 days'". - :arg attributes: + :param attributes: Date / time attributes to extract from the instance. e.g. 
['hours', 'minutes']
+
         """
         interval_components = []
         for attr_name in attributes:
diff --git a/piccolo/columns/defaults/date.py b/piccolo/columns/defaults/date.py
index b6fecf86f..b802c6764 100644
--- a/piccolo/columns/defaults/date.py
+++ b/piccolo/columns/defaults/date.py
@@ -1,14 +1,34 @@
 from __future__ import annotations
 
 import datetime
-import typing as t
+from collections.abc import Callable
 from enum import Enum
+from typing import Union
 
 from .base import Default
 
 
 class DateOffset(Default):
+    """
+    This makes the default value for a
+    :class:`Date <piccolo.columns.column_types.Date>` column the current date,
+    but offset by a number of days.
+
+    For example, if you wanted the default to be tomorrow, you can specify
+    ``DateOffset(days=1)``:
+
+    .. code-block:: python
+
+        class DiscountCode(Table):
+            expires = Date(default=DateOffset(days=1))
+
+    """
+
     def __init__(self, days: int):
+        """
+        :param days:
+            The number of days to offset.
+        """
         self.days = days
 
     @property
@@ -16,6 +36,10 @@ def postgres(self):
         interval_string = self.get_postgres_interval_string(["days"])
         return f"CURRENT_DATE + INTERVAL '{interval_string}'"
 
+    @property
+    def cockroach(self):
+        return self.postgres
+
     @property
     def sqlite(self):
         interval_string = self.get_sqlite_interval_string(["days"])
@@ -32,6 +56,10 @@ class DateNow(Default):
     def postgres(self):
         return "CURRENT_DATE"
 
+    @property
+    def cockroach(self):
+        return self.postgres
+
     @property
     def sqlite(self):
         return "CURRENT_DATE"
@@ -56,6 +84,10 @@ def __init__(
     def postgres(self):
         return f"'{self.date.isoformat()}'"
 
+    @property
+    def cockroach(self):
+        return self.postgres
+
     @property
     def sqlite(self):
         return f"'{self.date.isoformat()}'"
@@ -71,7 +103,15 @@ def from_date(cls, instance: datetime.date):
 
 # Might add an enum back which encapsulates all of the options.
-DateArg = t.Union[DateOffset, DateCustom, DateNow, Enum, None, datetime.date] +DateArg = Union[ + DateOffset, + DateCustom, + DateNow, + Enum, + None, + datetime.date, + Callable[[], datetime.date], +] __all__ = ["DateArg", "DateOffset", "DateCustom", "DateNow"] diff --git a/piccolo/columns/defaults/interval.py b/piccolo/columns/defaults/interval.py index b133a0e8b..798a4a050 100644 --- a/piccolo/columns/defaults/interval.py +++ b/piccolo/columns/defaults/interval.py @@ -1,13 +1,14 @@ from __future__ import annotations import datetime -import typing as t +from collections.abc import Callable from enum import Enum +from typing import Union from .base import Default -class IntervalCustom(Default): # lgtm [py/missing-equals] +class IntervalCustom(Default): def __init__( self, weeks: int = 0, @@ -53,6 +54,10 @@ def postgres(self): ) return f"'{value}'" + @property + def cockroach(self): + return self.postgres + @property def sqlite(self): return self.timedelta.total_seconds() @@ -71,11 +76,12 @@ def from_timedelta(cls, instance: datetime.timedelta): ############################################################################### -IntervalArg = t.Union[ +IntervalArg = Union[ IntervalCustom, Enum, None, datetime.timedelta, + Callable[[], datetime.timedelta], ] diff --git a/piccolo/columns/defaults/time.py b/piccolo/columns/defaults/time.py index fc22567d9..a32dcdf47 100644 --- a/piccolo/columns/defaults/time.py +++ b/piccolo/columns/defaults/time.py @@ -1,8 +1,9 @@ from __future__ import annotations import datetime -import typing as t +from collections.abc import Callable from enum import Enum +from typing import Union from .base import Default @@ -20,6 +21,13 @@ def postgres(self): ) return f"CURRENT_TIME + INTERVAL '{interval_string}'" + @property + def cockroach(self): + interval_string = self.get_postgres_interval_string( + ["hours", "minutes", "seconds"] + ) + return f"CURRENT_TIME::TIMESTAMP + INTERVAL '{interval_string}'" + @property def sqlite(self): interval_string = self.get_sqlite_interval_string( @@ -41,6 +49,10 @@ class TimeNow(Default): def postgres(self): return "CURRENT_TIME" + @property + def cockroach(self): + return "CURRENT_TIME::TIMESTAMP" + @property def sqlite(self): return "CURRENT_TIME" @@ -60,6 +72,10 @@ def __init__(self, hour: int, minute: int, second: int): def postgres(self): return f"'{self.time.isoformat()}'" + @property + def cockroach(self): + return f"'{self.time.isoformat()}'::TIMESTAMP" + @property def sqlite(self): return f"'{self.time.isoformat()}'" @@ -74,7 +90,15 @@ def from_time(cls, instance: datetime.time): ) -TimeArg = t.Union[TimeCustom, TimeNow, TimeOffset, Enum, None, datetime.time] +TimeArg = Union[ + TimeCustom, + TimeNow, + TimeOffset, + Enum, + None, + datetime.time, + Callable[[], datetime.time], +] __all__ = ["TimeArg", "TimeCustom", "TimeNow", "TimeOffset"] diff --git a/piccolo/columns/defaults/timestamp.py b/piccolo/columns/defaults/timestamp.py index c1264a284..11388c694 100644 --- a/piccolo/columns/defaults/timestamp.py +++ b/piccolo/columns/defaults/timestamp.py @@ -1,8 +1,9 @@ from __future__ import annotations import datetime -import typing as t +from collections.abc import Callable from enum import Enum +from typing import Union from .base import Default @@ -23,6 +24,13 @@ def postgres(self): ) return f"CURRENT_TIMESTAMP + INTERVAL '{interval_string}'" + @property + def cockroach(self): + interval_string = self.get_postgres_interval_string( + ["days", "hours", "minutes", "seconds"] + ) + return f"CURRENT_TIMESTAMP::TIMESTAMP + 
INTERVAL '{interval_string}'" + @property def sqlite(self): interval_string = self.get_sqlite_interval_string( @@ -44,6 +52,10 @@ class TimestampNow(Default): def postgres(self): return "current_timestamp" + @property + def cockroach(self): + return "current_timestamp::TIMESTAMP" + @property def sqlite(self): return "current_timestamp" @@ -59,6 +71,7 @@ def __init__( month: int = 1, day: int = 1, hour: int = 0, + minute: int = 0, second: int = 0, microsecond: int = 0, ): @@ -66,6 +79,7 @@ def __init__( self.month = month self.day = day self.hour = hour + self.minute = minute self.second = second self.microsecond = microsecond @@ -76,6 +90,7 @@ def datetime(self): month=self.month, day=self.day, hour=self.hour, + minute=self.minute, second=self.second, microsecond=self.microsecond, ) @@ -84,6 +99,12 @@ def datetime(self): def postgres(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) + @property + def cockroach(self): + return "'{}'::TIMESTAMP".format( + self.datetime.isoformat().replace("T", " ") + ) + @property def sqlite(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) @@ -96,8 +117,9 @@ def from_datetime(cls, instance: datetime.datetime): # type: ignore return cls( year=instance.year, month=instance.month, - day=instance.month, + day=instance.day, hour=instance.hour, + minute=instance.minute, second=instance.second, microsecond=instance.microsecond, ) @@ -113,7 +135,7 @@ class DatetimeDefault: ############################################################################### -TimestampArg = t.Union[ +TimestampArg = Union[ TimestampCustom, TimestampNow, TimestampOffset, @@ -121,6 +143,7 @@ class DatetimeDefault: None, datetime.datetime, DatetimeDefault, + Callable[[], datetime.datetime], ] diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index e52c2e62a..1cb6d32ff 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -1,13 +1,21 @@ from __future__ import annotations import datetime -import typing as t +from collections.abc import Callable from enum import Enum +from typing import Union from .timestamp import TimestampCustom, TimestampNow, TimestampOffset class TimestamptzOffset(TimestampOffset): + @property + def cockroach(self): + interval_string = self.get_postgres_interval_string( + ["days", "hours", "minutes", "seconds"] + ) + return f"CURRENT_TIMESTAMP + INTERVAL '{interval_string}'" + def python(self): return datetime.datetime.now( tz=datetime.timezone.utc @@ -20,11 +28,19 @@ def python(self): class TimestamptzNow(TimestampNow): + @property + def cockroach(self): + return "current_timestamp" + def python(self): return datetime.datetime.now(tz=datetime.timezone.utc) class TimestamptzCustom(TimestampCustom): + @property + def cockroach(self): + return "'{}'".format(self.datetime.isoformat().replace("T", " ")) + @property def datetime(self): return datetime.datetime( @@ -32,6 +48,7 @@ def datetime(self): month=self.month, day=self.day, hour=self.hour, + minute=self.minute, second=self.second, microsecond=self.microsecond, tzinfo=datetime.timezone.utc, @@ -44,20 +61,22 @@ def from_datetime(cls, instance: datetime.datetime): # type: ignore return cls( year=instance.year, month=instance.month, - day=instance.month, + day=instance.day, hour=instance.hour, + minute=instance.minute, second=instance.second, microsecond=instance.microsecond, ) -TimestamptzArg = t.Union[ +TimestamptzArg = Union[ TimestamptzCustom, TimestamptzNow, TimestamptzOffset, Enum, 
    None,
     datetime.datetime,
+    Callable[[], datetime.datetime],
 ]
diff --git a/piccolo/columns/defaults/uuid.py b/piccolo/columns/defaults/uuid.py
index 44a5c76c0..ad0e34679 100644
--- a/piccolo/columns/defaults/uuid.py
+++ b/piccolo/columns/defaults/uuid.py
@@ -1,14 +1,27 @@
-import typing as t
 import uuid
+from collections.abc import Callable
 from enum import Enum
+from typing import Union
 
 from .base import Default
 
 
 class UUID4(Default):
+    """
+    This makes the default value for a
+    :class:`UUID <piccolo.columns.column_types.UUID>` column a randomly
+    generated UUID v4 value. The advantage over using :func:`uuid.uuid4` from
+    the standard library is that the default is set on the column definition
+    in the database too.
+    """
+
     @property
     def postgres(self):
-        return "uuid_generate_v4()"
+        return "gen_random_uuid()"
+
+    @property
+    def cockroach(self):
+        return self.postgres
 
     @property
     def sqlite(self):
@@ -18,7 +31,7 @@ def python(self):
         return uuid.uuid4()
 
 
-UUIDArg = t.Union[UUID4, uuid.UUID, str, Enum, None]
+UUIDArg = Union[UUID4, uuid.UUID, str, Enum, None, Callable[[], uuid.UUID]]
 
 
 __all__ = ["UUIDArg", "UUID4"]
diff --git a/piccolo/columns/indexes.py b/piccolo/columns/indexes.py
index 9b25a1f27..79060277f 100644
--- a/piccolo/columns/indexes.py
+++ b/piccolo/columns/indexes.py
@@ -2,6 +2,11 @@
 
 
 class IndexMethod(str, Enum):
+    """
+    Used to specify the index method for a
+    :class:`Column <piccolo.columns.Column>`.
+    """
+
     btree = "btree"
     hash = "hash"
     gist = "gist"
diff --git a/piccolo/columns/m2m.py b/piccolo/columns/m2m.py
new file mode 100644
index 000000000..b8d34d46b
--- /dev/null
+++ b/piccolo/columns/m2m.py
@@ -0,0 +1,447 @@
+from __future__ import annotations
+
+import inspect
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from piccolo.columns.column_types import (
+    JSON,
+    JSONB,
+    Column,
+    ForeignKey,
+    LazyTableReference,
+)
+from piccolo.querystring import QueryString, Selectable
+from piccolo.utils.list import flatten
+from piccolo.utils.sync import run_sync
+
+if TYPE_CHECKING:  # pragma: no cover
+    from piccolo.table import Table
+
+
+class M2MSelect(Selectable):
+    """
+    This is a subquery used within a select to fetch data via an M2M table.
+    """
+
+    def __init__(
+        self,
+        *columns: Column,
+        m2m: M2M,
+        as_list: bool = False,
+        load_json: bool = False,
+    ):
+        """
+        :param columns:
+            Which columns to include from the related table.
+        :param as_list:
+            If a single column is provided, and ``as_list`` is ``True`` a
+            flattened list will be returned, rather than a list of objects.
+        :param load_json:
+            If ``True``, any JSON strings are loaded as Python objects.
+
+        """
+        self.as_list = as_list
+        self.columns = columns
+        self.m2m = m2m
+        self.load_json = load_json
+
+        safe_types = (int, str)
+
+        # If the columns can be serialised / deserialised as JSON, then we
+        # can fetch the data all in one go.
+ self.serialisation_safe = all( + (column.__class__.value_type in safe_types) + and (type(column) not in (JSON, JSONB)) + for column in columns + ) + + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: + m2m_table_name_with_schema = ( + self.m2m._meta.resolved_joining_table._meta.get_formatted_tablename() # noqa: E501 + ) # noqa: E501 + m2m_relationship_name = self.m2m._meta.name + + fk_1 = self.m2m._meta.primary_foreign_key + fk_1_name = fk_1._meta.db_column_name + table_1 = fk_1._foreign_key_meta.resolved_references + table_1_name = table_1._meta.tablename + table_1_name_with_schema = table_1._meta.get_formatted_tablename() + table_1_pk_name = table_1._meta.primary_key._meta.db_column_name + + fk_2 = self.m2m._meta.secondary_foreign_key + fk_2_name = fk_2._meta.db_column_name + table_2 = fk_2._foreign_key_meta.resolved_references + table_2_name = table_2._meta.tablename + table_2_name_with_schema = table_2._meta.get_formatted_tablename() + table_2_pk_name = table_2._meta.primary_key._meta.db_column_name + + inner_select = f""" + {m2m_table_name_with_schema} + JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON ( + {m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}" + ) + JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON ( + {m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}" + ) + WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}" + """ # noqa: E501 + + if engine_type in ("postgres", "cockroach"): + if self.as_list: + column_name = self.columns[0]._meta.db_column_name + return QueryString( + f""" + ARRAY( + SELECT + "inner_{table_2_name}"."{column_name}" + FROM {inner_select} + ) AS "{m2m_relationship_name}" + """ + ) + elif not self.serialisation_safe: + column_name = table_2_pk_name + return QueryString( + f""" + ARRAY( + SELECT + "inner_{table_2_name}"."{column_name}" + FROM {inner_select} + ) AS "{m2m_relationship_name}" + """ + ) + else: + column_names = ", ".join( + f'"inner_{table_2_name}"."{column._meta.db_column_name}"' + for column in self.columns + ) + return QueryString( + f""" + ( + SELECT JSON_AGG({m2m_relationship_name}_results) + FROM ( + SELECT {column_names} FROM {inner_select} + ) AS "{m2m_relationship_name}_results" + ) AS "{m2m_relationship_name}" + """ + ) + elif engine_type == "sqlite": + if len(self.columns) > 1 or not self.serialisation_safe: + column_name = table_2_pk_name + else: + assert len(self.columns) > 0 + column_name = self.columns[0]._meta.db_column_name + + return QueryString( + f""" + ( + SELECT group_concat( + "inner_{table_2_name}"."{column_name}" + ) + FROM {inner_select} + ) + AS "{m2m_relationship_name} [M2M]" + """ + ) + else: + raise ValueError(f"{engine_type} is an unrecognised engine type") + + +@dataclass +class M2MMeta: + joining_table: Union[type[Table], LazyTableReference] + _foreign_key_columns: Optional[list[ForeignKey]] = None + + # Set by the Table Metaclass: + _name: Optional[str] = None + _table: Optional[type[Table]] = None + + @property + def name(self) -> str: + if not self._name: + raise ValueError( + "`_name` isn't defined - the Table Metaclass should set it." + ) + return self._name + + @property + def table(self) -> type[Table]: + if not self._table: + raise ValueError( + "`_table` isn't defined - the Table Metaclass should set it." 
+            )
+        return self._table
+
+    @property
+    def resolved_joining_table(self) -> type[Table]:
+        """
+        Evaluates the ``joining_table`` attribute if it's a
+        ``LazyTableReference``, raising a ``ValueError`` if it fails, otherwise
+        returns a ``Table`` subclass.
+        """
+        from piccolo.table import Table
+
+        if isinstance(self.joining_table, LazyTableReference):
+            return self.joining_table.resolve()
+        elif inspect.isclass(self.joining_table) and issubclass(
+            self.joining_table, Table
+        ):
+            return self.joining_table
+        else:
+            raise ValueError(
+                "The joining_table attribute is neither a Table subclass nor "
+                "a LazyTableReference instance."
+            )
+
+    @property
+    def foreign_key_columns(self) -> list[ForeignKey]:
+        if not self._foreign_key_columns:
+            self._foreign_key_columns = (
+                self.resolved_joining_table._meta.foreign_key_columns[:2]
+            )
+        return self._foreign_key_columns
+
+    @property
+    def primary_foreign_key(self) -> ForeignKey:
+        """
+        The joining table has two foreign keys. We need a way to distinguish
+        between them. The primary is the one which points to the table with
+        ``M2M`` defined on it. In this example the primary foreign key is the
+        one which points to ``Band``:
+
+        .. code-block:: python
+
+            class Band(Table):
+                name = Varchar()
+                genres = M2M(
+                    LazyTableReference("GenreToBand", module_path=__name__)
+                )
+
+            class Genre(Table):
+                name = Varchar()
+
+            class GenreToBand(Table):
+                band = ForeignKey(Band)  # primary
+                genre = ForeignKey(Genre)  # secondary
+
+        The secondary foreign key is the one which points to ``Genre``.
+
+        """
+        for fk_column in self.foreign_key_columns:
+            if fk_column._foreign_key_meta.resolved_references == self.table:
+                return fk_column
+
+        raise ValueError("No matching foreign key column found!")
+
+    @property
+    def primary_table(self) -> type[Table]:
+        return self.primary_foreign_key._foreign_key_meta.resolved_references
+
+    @property
+    def secondary_foreign_key(self) -> ForeignKey:
+        """
+        See ``primary_foreign_key``.
+ """ + for fk_column in self.foreign_key_columns: + if fk_column._foreign_key_meta.resolved_references != self.table: + return fk_column + + raise ValueError("No matching foreign key column found!") + + @property + def secondary_table(self) -> type[Table]: + return self.secondary_foreign_key._foreign_key_meta.resolved_references + + +@dataclass +class M2MAddRelated: + target_row: Table + m2m: M2M + rows: Sequence[Table] + extra_column_values: dict[Union[Column, str], Any] + + @property + def resolved_extra_column_values(self) -> dict[str, Any]: + return { + i._meta.name if isinstance(i, Column) else i: j + for i, j in self.extra_column_values.items() + } + + async def _run(self): + rows = self.rows + unsaved = [i for i in rows if not i._exists_in_db] + + if unsaved: + await rows[0].__class__.insert(*unsaved).run() + + joining_table = self.m2m._meta.resolved_joining_table + + joining_table_rows = [] + + for row in rows: + joining_table_row = joining_table( + **self.resolved_extra_column_values + ) + setattr( + joining_table_row, + self.m2m._meta.primary_foreign_key._meta.name, + getattr( + self.target_row, + self.target_row._meta.primary_key._meta.name, + ), + ) + setattr( + joining_table_row, + self.m2m._meta.secondary_foreign_key._meta.name, + getattr( + row, + row._meta.primary_key._meta.name, + ), + ) + joining_table_rows.append(joining_table_row) + + return await joining_table.insert(*joining_table_rows).run() + + async def run(self): + """ + Run the queries, making sure they are either within an existing + transaction, or wrapped in a new transaction. + """ + engine = self.rows[0]._meta.db + async with engine.transaction(): + return await self._run() + + def run_sync(self): + return run_sync(self.run()) + + def __await__(self): + return self.run().__await__() + + +@dataclass +class M2MRemoveRelated: + target_row: Table + m2m: M2M + rows: Sequence[Table] + + async def run(self): + fk = self.m2m._meta.secondary_foreign_key + related_table = fk._foreign_key_meta.resolved_references + + row_ids = [] + + for row in self.rows: + if row.__class__ != related_table: + raise ValueError("The row belongs to the wrong table!") + + row_id = getattr(row, row._meta.primary_key._meta.name) + if row_id: + row_ids.append(row_id) + + if row_ids: + return ( + await self.m2m._meta.resolved_joining_table.delete() + .where( + self.m2m._meta.primary_foreign_key == self.target_row, + self.m2m._meta.secondary_foreign_key.is_in(row_ids), + ) + .run() + ) + + return None + + def run_sync(self): + return run_sync(self.run()) + + def __await__(self): + return self.run().__await__() + + +@dataclass +class M2MGetRelated: + row: Table + m2m: M2M + + async def run(self): + joining_table = self.m2m._meta.resolved_joining_table + + secondary_table = self.m2m._meta.secondary_table + + # use a subquery to make only one db query + results = await secondary_table.objects().where( + secondary_table._meta.primary_key.is_in( + joining_table.select( + getattr( + self.m2m._meta.secondary_foreign_key, + secondary_table._meta.primary_key._meta.name, + ) + ).where(self.m2m._meta.primary_foreign_key == self.row) + ) + ) + + return results + + def run_sync(self): + return run_sync(self.run()) + + def __await__(self): + return self.run().__await__() + + +class M2M: + def __init__( + self, + joining_table: Union[type[Table], LazyTableReference], + foreign_key_columns: Optional[list[ForeignKey]] = None, + ): + """ + :param joining_table: + A ``Table`` containing two ``ForeignKey`` columns. 
+ :param foreign_key_columns: + If for some reason your joining table has more than two foreign key + columns, you can explicitly specify which two are relevant. + + """ + if foreign_key_columns and ( + len(foreign_key_columns) != 2 + or not all(isinstance(i, ForeignKey) for i in foreign_key_columns) + ): + raise ValueError("You must specify two ForeignKey columns.") + + self._meta = M2MMeta( + joining_table=joining_table, + _foreign_key_columns=foreign_key_columns, + ) + + def __call__( + self, + *columns: Union[Column, list[Column]], + as_list: bool = False, + load_json: bool = False, + ) -> M2MSelect: + """ + :param columns: + Which columns to include from the related table. If none are + specified, then all of the columns are returned. + :param as_list: + If a single column is provided, and ``as_list`` is ``True`` a + flattened list will be returned, rather than a list of objects. + :param load_json: + If ``True``, any JSON strings are loaded as Python objects. + """ + columns_ = flatten(columns) + + if not columns_: + columns_ = self._meta.secondary_table._meta.columns + + if as_list and len(columns_) != 1: + raise ValueError( + "`as_list` is only valid with a single column argument" + ) + + return M2MSelect( + *columns_, m2m=self, as_list=as_list, load_json=load_json + ) diff --git a/piccolo/columns/operators/__init__.py b/piccolo/columns/operators/__init__.py index ee0805499..603264170 100644 --- a/piccolo/columns/operators/__init__.py +++ b/piccolo/columns/operators/__init__.py @@ -16,4 +16,4 @@ NotLike, ) from .math import Add, Divide, Multiply, Subtract -from .string import ConcatPostgres, ConcatSQLite +from .string import Concat diff --git a/piccolo/columns/operators/comparison.py b/piccolo/columns/operators/comparison.py index 2811c3836..91b565361 100644 --- a/piccolo/columns/operators/comparison.py +++ b/piccolo/columns/operators/comparison.py @@ -62,5 +62,9 @@ class ArrayAny(ComparisonOperator): template = "{value} = ANY ({name})" +class ArrayNotAny(ComparisonOperator): + template = "NOT {value} = ANY ({name})" + + class ArrayAll(ComparisonOperator): template = "{value} = ALL ({name})" diff --git a/piccolo/columns/operators/string.py b/piccolo/columns/operators/string.py index 51b0c745f..1c979f94b 100644 --- a/piccolo/columns/operators/string.py +++ b/piccolo/columns/operators/string.py @@ -5,9 +5,5 @@ class StringOperator(Operator): pass -class ConcatPostgres(StringOperator): - template = "CONCAT({value_1}, {value_2})" - - -class ConcatSQLite(StringOperator): +class Concat(StringOperator): template = "{value_1} || {value_2}" diff --git a/piccolo/columns/readable.py b/piccolo/columns/readable.py index bf007e939..cd02c5c91 100644 --- a/piccolo/columns/readable.py +++ b/piccolo/columns/readable.py @@ -1,11 +1,12 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence from dataclasses import dataclass +from typing import TYPE_CHECKING -from piccolo.columns.base import Selectable +from piccolo.querystring import QueryString, Selectable -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import Column @@ -18,33 +19,39 @@ class Readable(Selectable): """ template: str - columns: t.Sequence[Column] + columns: Sequence[Column] output_name: str = "readable" @property def _columns_string(self) -> str: return ", ".join( - [i._meta.get_full_name(just_alias=True) for i in self.columns] + i._meta.get_full_name(with_alias=False) for i in self.columns ) - def _get_string(self, operator: str) -> str: - 
return (
+    def _get_string(self, operator: str) -> QueryString:
+        return QueryString(
             f"{operator}('{self.template}', {self._columns_string}) AS "
             f"{self.output_name}"
         )
 
     @property
-    def sqlite_string(self) -> str:
+    def sqlite_string(self) -> QueryString:
         return self._get_string(operator="PRINTF")
 
     @property
-    def postgres_string(self) -> str:
+    def postgres_string(self) -> QueryString:
         return self._get_string(operator="FORMAT")
 
-    def get_select_string(self, engine_type: str, just_alias=False) -> str:
+    @property
+    def cockroach_string(self) -> QueryString:
+        return self._get_string(operator="FORMAT")
+
+    def get_select_string(
+        self, engine_type: str, with_alias=True
+    ) -> QueryString:
         try:
             return getattr(self, f"{engine_type}_string")
-        except AttributeError:
+        except AttributeError as e:
             raise ValueError(
                 f"Unrecognised engine_type - received {engine_type}"
-            )
+            ) from e
diff --git a/piccolo/columns/reference.py b/piccolo/columns/reference.py
index 167712ab2..532cf309d 100644
--- a/piccolo/columns/reference.py
+++ b/piccolo/columns/reference.py
@@ -1,14 +1,15 @@
 """
 Dataclasses for storing lazy references between ForeignKey columns and tables.
 """
+
 from __future__ import annotations
 
 import importlib
 import inspect
-import typing as t
 from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Optional
 
-if t.TYPE_CHECKING:  # pragma: no cover
+if TYPE_CHECKING:  # pragma: no cover
     from piccolo.columns.column_types import ForeignKey
     from piccolo.table import Table
 
@@ -16,23 +17,27 @@
 @dataclass
 class LazyTableReference:
     """
-    Holds a reference to a ``Table`` subclass. Used to avoid circular
-    dependencies in the ``references`` argument of ``ForeignKey`` columns.
+    Holds a reference to a :class:`Table <piccolo.table.Table>` subclass. Used
+    to avoid circular dependencies in the ``references`` argument of
+    :class:`ForeignKey <piccolo.columns.column_types.ForeignKey>` columns.
 
     :param table_class_name:
-        The name of the ``Table`` subclass. For example, 'Manager'.
+        The name of the ``Table`` subclass. For example, ``'Manager'``.
     :param app_name:
         If specified, the ``Table`` subclass is imported from a Piccolo app
         with the given name.
     :param module_path:
         If specified, the ``Table`` subclass is imported from this path.
-        For example, 'my_app.tables'.
+        For example, ``'my_app.tables'``.
+
+        .. hint::
+            If the table is in the same file, you can pass in ``__name__``.
 
     """
 
     table_class_name: str
-    app_name: t.Optional[str] = None
-    module_path: t.Optional[str] = None
+    app_name: Optional[str] = None
+    module_path: Optional[str] = None
 
     def __post_init__(self):
         if self.app_name is None and self.module_path is None:
@@ -44,7 +49,7 @@ def __post_init__(self):
                 "Specify either app_name or module_path - not both."
) - def resolve(self) -> t.Type[Table]: + def resolve(self) -> type[Table]: if self.app_name is not None: from piccolo.conf.apps import Finder @@ -55,7 +60,7 @@ def resolve(self) -> t.Type[Table]: if self.module_path: module = importlib.import_module(self.module_path) - table: t.Optional[t.Type[Table]] = getattr( + table: Optional[type[Table]] = getattr( module, self.table_class_name, None ) @@ -69,8 +74,8 @@ def resolve(self) -> t.Type[Table]: return table else: raise ValueError( - f"Can't find a Table subclass called {self.app_name} " - f"in {self.module_path}" + "Can't find a Table subclass called " + f"{self.table_class_name} in {self.module_path}" ) raise ValueError("You must specify either app_name or module_path.") @@ -86,9 +91,9 @@ def __str__(self): @dataclass class LazyColumnReferenceStore: - foreign_key_columns: t.List[ForeignKey] = field(default_factory=list) + foreign_key_columns: list[ForeignKey] = field(default_factory=list) - def for_table(self, table: t.Type[Table]) -> t.List[ForeignKey]: + def for_table(self, table: type[Table]) -> list[ForeignKey]: return [ i for i in self.foreign_key_columns @@ -96,7 +101,7 @@ def for_table(self, table: t.Type[Table]) -> t.List[ForeignKey]: and i._foreign_key_meta.references.resolve() is table ] - def for_tablename(self, tablename: str) -> t.List[ForeignKey]: + def for_tablename(self, tablename: str) -> list[ForeignKey]: return [ i for i in self.foreign_key_columns diff --git a/piccolo/conf/apps.py b/piccolo/conf/apps.py index 3464156a2..983e31e04 100644 --- a/piccolo/conf/apps.py +++ b/piccolo/conf/apps.py @@ -1,15 +1,22 @@ from __future__ import annotations -import functools +import ast import inspect import itertools import os +import pathlib import traceback -import typing as t +from abc import abstractmethod +from collections.abc import Callable, Sequence from dataclasses import dataclass, field +from graphlib import TopologicalSorter from importlib import import_module from types import ModuleType +from typing import Optional, Union, cast +import black + +from piccolo.apps.migrations.auto.migration_manager import MigrationManager from piccolo.engine.base import Engine from piccolo.table import Table from piccolo.utils.warnings import Level, colored_warning @@ -21,51 +28,91 @@ class MigrationModule(ModuleType): DESCRIPTION: str @staticmethod - async def forwards() -> None: - pass + @abstractmethod + async def forwards() -> MigrationManager: ... class PiccoloAppModule(ModuleType): APP_CONFIG: AppConfig +def get_package(name: str) -> str: + """ + :param name: + The __name__ variable from a Python file. + + """ + return ".".join(name.split(".")[:-1]) + + def table_finder( - modules: t.Sequence[str], - include_tags: t.Sequence[str] = ["__all__"], - exclude_tags: t.Sequence[str] = [], -) -> t.List[t.Type[Table]]: + modules: Sequence[str], + package: Optional[str] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + exclude_imported: bool = False, +) -> list[type[Table]]: """ Rather than explicitly importing and registering table classes with the - AppConfig, ``table_finder`` can be used instead. It imports any ``Table`` + ``AppConfig``, ``table_finder`` can be used instead. It imports any ``Table`` subclasses in the given modules. Tags can be used to limit which ``Table`` subclasses are imported. :param modules: The module paths to check for ``Table`` subclasses. For example, - ['blog.tables']. The path should be from the root of your project, - not a relative path. 
+        ``['blog.tables']``.
+    :param package:
+        This must be passed in if the modules are relative paths (e.g.
+        if ``modules=['.tables']`` then ``package='blog'``).
     :param include_tags:
         If the ``Table`` subclass has one of these tags, it will be
-        imported. The special tag '__all__' will import all ``Table``
+        imported. The special tag ``'__all__'`` will import all ``Table``
         subclasses found.
     :param exclude_tags:
         If the ``Table`` subclass has any of these tags, it won't be
-        imported. `exclude_tags` overrides `include_tags`.
-
-    """
+        imported. ``exclude_tags`` overrides ``include_tags``.
+    :param exclude_imported:
+        If ``True``, only ``Table`` subclasses defined within the module are
+        used. Any ``Table`` subclasses imported by that module from other
+        modules are ignored. For example:
+
+        .. code-block:: python
+
+            from piccolo.table import Table
+            from piccolo.columns import Varchar, ForeignKey
+            from piccolo.apps.user.tables import BaseUser  # excluded
+
+            class Task(Table):  # included
+                title = Varchar()
+                creator = ForeignKey(BaseUser)
+
+    """  # noqa: E501
+    if include_tags is None:
+        include_tags = ["__all__"]
+    if exclude_tags is None:
+        exclude_tags = []
     if isinstance(modules, str):
         # Guard against the user just entering a string, for example
         # 'blog.tables', instead of ['blog.tables'].
         modules = [modules]
 
-    table_subclasses: t.List[t.Type[Table]] = []
+    table_subclasses: list[type[Table]] = []
 
     for module_path in modules:
+        full_module_path = (
+            ".".join([package, module_path.lstrip(".")])
+            if package
+            else module_path
+        )
+
         try:
-            module = import_module(module_path)
+            module = import_module(
+                module_path,
+                package=package,
+            )
         except ImportError as exception:
-            print(f"Unable to import {module_path}")
-            raise exception
+            print(f"Unable to import {full_module_path}")
+            raise exception from exception
 
         object_names = [i for i in dir(module) if not i.startswith("_")]
 
@@ -76,7 +123,11 @@ def table_finder(
             and issubclass(_object, Table)
             and _object is not Table
         ):
-            table: Table = _object
+            table: Table = _object  # type: ignore
+
+            if exclude_imported and table.__module__ != full_module_path:
+                continue
+
             if exclude_tags and set(table._meta.tags).intersection(
                 set(exclude_tags)
             ):
@@ -91,8 +142,19 @@ def table_finder(
 
 @dataclass
 class Command:
-    callable: t.Callable
-    aliases: t.List[str] = field(default_factory=list)
+    """
+    :param callable:
+        The function or method to be called.
+    :param command_name:
+        If not specified, the name of the ``callable`` is used.
+    :param aliases:
+        Alternative ways to refer to this command in the CLI.
+ + """ + + callable: Callable + command_name: Optional[str] = None + aliases: list[str] = field(default_factory=list) @dataclass @@ -120,84 +182,113 @@ class AppConfig: """ app_name: str - migrations_folder_path: str - table_classes: t.List[t.Type[Table]] = field(default_factory=list) - migration_dependencies: t.List[str] = field(default_factory=list) - commands: t.List[t.Union[t.Callable, Command]] = field( - default_factory=list - ) - - def __post_init__(self): - self.commands = [ - i if isinstance(i, Command) else Command(i) for i in self.commands - ] + migrations_folder_path: Union[str, pathlib.Path] + table_classes: list[type[Table]] = field(default_factory=list) + migration_dependencies: list[str] = field(default_factory=list) + commands: list[Union[Callable, Command]] = field(default_factory=list) + + @property + def resolved_migrations_folder_path(self) -> str: + return ( + str(self.migrations_folder_path) + if isinstance(self.migrations_folder_path, pathlib.Path) + else self.migrations_folder_path + ) - def register_table(self, table_class: t.Type[Table]): + def __post_init__(self) -> None: + self._migration_dependency_app_configs: Optional[list[AppConfig]] = ( + None + ) + + def register_table(self, table_class: type[Table]): self.table_classes.append(table_class) return table_class + def get_commands(self) -> list[Command]: + return [ + i if isinstance(i, Command) else Command(i) for i in self.commands + ] + @property - def migration_dependency_app_configs(self) -> t.List[AppConfig]: + def migration_dependency_app_configs(self) -> list[AppConfig]: """ - Get all of the AppConfig instances from this app's migration + Get all of the ``AppConfig`` instances from this app's migration dependencies. """ - modules: t.List[PiccoloAppModule] = [ - t.cast(PiccoloAppModule, import_module(module_path)) - for module_path in self.migration_dependencies - ] - return [i.APP_CONFIG for i in modules] - - def get_table_with_name(self, table_class_name: str) -> t.Type[Table]: + # We cache the value so it's more efficient, and also so we can set the + # underlying value in unit tests for easier mocking. + if self._migration_dependency_app_configs is None: + modules: list[PiccoloAppModule] = [ + cast(PiccoloAppModule, import_module(module_path)) + for module_path in self.migration_dependencies + ] + self._migration_dependency_app_configs = [ + i.APP_CONFIG for i in modules + ] + + return self._migration_dependency_app_configs + + def get_table_with_name(self, table_class_name: str) -> type[Table]: """ - Returns a Table subclass with the given name from this app, if it - exists. Otherwise raises a ValueError. + Returns a ``Table`` subclass with the given name from this app, if it + exists. Otherwise raises a ``ValueError``. """ filtered = [ table_class for table_class in self.table_classes if table_class.__name__ == table_class_name ] - if len(filtered) == 0: + if not filtered: raise ValueError( f"No table with class name {table_class_name} exists." ) return filtered[0] -@dataclass class AppRegistry: """ - Records all of the Piccolo apps in your project. Kept in piccolo_conf.py. + Records all of the Piccolo apps in your project. Kept in + ``piccolo_conf.py``. :param apps: - A list of paths to Piccolo apps, e.g. ['blog.piccolo_app'] + A list of paths to Piccolo apps, e.g. ``['blog.piccolo_app']``. 
""" - apps: t.List[str] = field(default_factory=list) - - def __post_init__(self): - self.app_configs: t.Dict[str, AppConfig] = {} + def __init__(self, apps: Optional[list[str]] = None): + self.apps = apps or [] + self.app_configs: dict[str, AppConfig] = {} app_names = [] for app in self.apps: - app_conf_module = import_module(app) - app_config: AppConfig = getattr(app_conf_module, "APP_CONFIG") + try: + app_conf_module = import_module(app) + app_config: AppConfig = getattr(app_conf_module, "APP_CONFIG") + except (ImportError, AttributeError) as e: + if app.endswith(".piccolo_app"): + raise e from e + app += ".piccolo_app" + app_conf_module = import_module(app) + app_config = getattr(app_conf_module, "APP_CONFIG") + colored_warning( + f"App {app[:-12]} should end with `.piccolo_app`", + level=Level.medium, + ) + self.app_configs[app_config.app_name] = app_config app_names.append(app_config.app_name) self._validate_app_names(app_names) @staticmethod - def _validate_app_names(app_names: t.List[str]): + def _validate_app_names(app_names: list[str]): """ Raise a ValueError if an app_name is repeated. """ app_names.sort() grouped = itertools.groupby(app_names) for key, value in grouped: - count = len([i for i in value]) + count = len(list(value)) if count > 1: raise ValueError( f"There are {count} apps with the name `{key}`. This can " @@ -206,10 +297,10 @@ def _validate_app_names(app_names: t.List[str]): "multiple times." ) - def get_app_config(self, app_name: str) -> t.Optional[AppConfig]: + def get_app_config(self, app_name: str) -> Optional[AppConfig]: return self.app_configs.get(app_name) - def get_table_classes(self, app_name: str) -> t.List[t.Type[Table]]: + def get_table_classes(self, app_name: str) -> list[type[Table]]: """ Returns each Table subclass defined in the given app if it exists. Otherwise raises a ValueError. @@ -225,7 +316,7 @@ def get_table_classes(self, app_name: str) -> t.List[t.Type[Table]]: def get_table_with_name( self, app_name: str, table_class_name: str - ) -> t.Optional[t.Type[Table]]: + ) -> Optional[type[Table]]: """ Returns a Table subclass registered with the given app if it exists. Otherwise raises a ValueError. @@ -265,17 +356,17 @@ def __init__(self, diagnose: bool = False): self.diagnose = diagnose def _deduplicate( - self, config_modules: t.List[PiccoloAppModule] - ) -> t.List[PiccoloAppModule]: + self, config_modules: list[PiccoloAppModule] + ) -> list[PiccoloAppModule]: """ Remove all duplicates - just leaving the first instance. """ # Deduplicate, but preserve order - which is why set() isn't used. - return list(dict([(c, None) for c in config_modules]).keys()) + return list({c: None for c in config_modules}.keys()) def _import_app_modules( - self, config_module_paths: t.List[str] - ) -> t.List[PiccoloAppModule]: + self, config_module_paths: list[str] + ) -> list[PiccoloAppModule]: """ Import all piccolo_app.py modules within your apps, and all dependencies. 
@@ -284,11 +375,13 @@ def _import_app_modules( for config_module_path in config_module_paths: try: - config_module = t.cast( + config_module = cast( PiccoloAppModule, import_module(config_module_path) ) - except ImportError: - raise Exception(f"Unable to import {config_module_path}") + except ImportError as e: + raise Exception( + f"Unable to import {config_module_path}" + ) from e app_config: AppConfig = getattr(config_module, "APP_CONFIG") dependency_config_modules = self._import_app_modules( app_config.migration_dependencies @@ -298,14 +391,14 @@ def _import_app_modules( return config_modules def get_piccolo_conf_module( - self, module_name: t.Optional[str] = None - ) -> t.Optional[PiccoloConfModule]: + self, module_name: Optional[str] = None + ) -> Optional[PiccoloConfModule]: """ Searches the path for a 'piccolo_conf.py' module to import. The location searched can be overriden by: - * Explicitly passing a module name into this method. - * Setting the PICCOLO_CONF environment variable. + * Explicitly passing a module name into this method. + * Setting the PICCOLO_CONF environment variable. An example override is 'my_folder.piccolo_conf'. @@ -319,7 +412,7 @@ def get_piccolo_conf_module( module_name = DEFAULT_MODULE_NAME try: - module = t.cast(PiccoloConfModule, import_module(module_name)) + module = cast(PiccoloConfModule, import_module(module_name)) except ModuleNotFoundError as exc: if self.diagnose: colored_warning( @@ -335,31 +428,40 @@ def get_piccolo_conf_module( raise ModuleNotFoundError( "PostgreSQL driver not found. " "Try running `pip install 'piccolo[postgres]'`" - ) + ) from exc elif str(exc) == "No module named 'aiosqlite'": raise ModuleNotFoundError( "SQLite driver not found. " "Try running `pip install 'piccolo[sqlite]'`" - ) + ) from exc else: - raise exc + raise exc from exc else: return module + def get_piccolo_conf_path(self) -> str: + piccolo_conf_module = self.get_piccolo_conf_module() + + if piccolo_conf_module is None: + raise ModuleNotFoundError("piccolo_conf.py not found.") + + module_file_path = piccolo_conf_module.__file__ + assert module_file_path + + return module_file_path + def get_app_registry(self) -> AppRegistry: """ - Returns the AppRegistry instance within piccolo_conf. + Returns the ``AppRegistry`` instance within piccolo_conf. """ piccolo_conf_module = self.get_piccolo_conf_module() - app_registry = getattr(piccolo_conf_module, "APP_REGISTRY") - return app_registry + return getattr(piccolo_conf_module, "APP_REGISTRY") def get_engine( - self, module_name: t.Optional[str] = None - ) -> t.Optional[Engine]: + self, module_name: Optional[str] = None + ) -> Optional[Engine]: piccolo_conf = self.get_piccolo_conf_module(module_name=module_name) - engine: t.Optional[Engine] = None - engine = getattr(piccolo_conf, ENGINE_VAR, None) + engine: Optional[Engine] = getattr(piccolo_conf, ENGINE_VAR, None) if not engine: colored_warning( @@ -375,10 +477,10 @@ def get_engine( return engine - def get_app_modules(self) -> t.List[PiccoloAppModule]: + def get_app_modules(self) -> list[PiccoloAppModule]: """ - Returns the piccolo_app.py modules for each registered Piccolo app in - your project. + Returns the ``piccolo_app.py`` modules for each registered Piccolo app + in your project. 
""" app_registry = self.get_app_registry() app_modules = self._import_app_modules(app_registry.apps) @@ -388,43 +490,194 @@ def get_app_modules(self) -> t.List[PiccoloAppModule]: return app_modules - def get_sorted_app_names(self) -> t.List[str]: - """ - Sorts the app names using the migration dependencies, so dependencies - are before dependents in the list. + def get_app_names( + self, sort_by_migration_dependencies: bool = True + ) -> list[str]: """ - modules = self.get_app_modules() - configs: t.List[AppConfig] = [module.APP_CONFIG for module in modules] + Return all of the app names. + + :param sort_by_migration_dependencies: + If True, sorts the app names using the migration dependencies, so + dependencies are before dependents in the list. - def sort_app_configs(app_config_1: AppConfig, app_config_2: AppConfig): - return ( - app_config_1 in app_config_2.migration_dependency_app_configs + """ + return [ + i.app_name + for i in self.get_app_configs( + sort_by_migration_dependencies=sort_by_migration_dependencies ) + ] - sorted_configs = sorted( - configs, key=functools.cmp_to_key(sort_app_configs) + def get_sorted_app_names(self) -> list[str]: + """ + Just here for backwards compatibility - use ``get_app_names`` directly. + """ + return self.get_app_names(sort_by_migration_dependencies=True) + + def sort_app_configs( + self, app_configs: list[AppConfig] + ) -> list[AppConfig]: + app_config_map = { + app_config.app_name: app_config for app_config in app_configs + } + + sorted_app_names = TopologicalSorter( + { + app_config.app_name: [ + i.app_name + for i in app_config.migration_dependency_app_configs + ] + for app_config in app_config_map.values() + } + ).static_order() + + return [app_config_map[i] for i in sorted_app_names] + + def get_app_configs( + self, sort_by_migration_dependencies: bool = True + ) -> list[AppConfig]: + """ + Returns a list of ``AppConfig``, optionally sorted by migration + dependencies. + """ + app_configs = [i.APP_CONFIG for i in self.get_app_modules()] + + return ( + self.sort_app_configs(app_configs=app_configs) + if sort_by_migration_dependencies + else app_configs ) - return [i.app_name for i in sorted_configs] def get_app_config(self, app_name: str) -> AppConfig: """ - Returns an `AppConfig` for the given app name. + Returns an ``AppConfig`` for the given app name. """ - modules = self.get_app_modules() - for module in modules: - app_config = module.APP_CONFIG + for app_config in self.get_app_configs(): if app_config.app_name == app_name: return app_config raise ValueError(f"No app found with name {app_name}") def get_table_with_name( self, app_name: str, table_class_name: str - ) -> t.Type[Table]: + ) -> type[Table]: """ - Returns a Table subclass registered with the given app if it exists. - Otherwise it raises an ValueError. + Returns a ``Table`` class registered with the given app if it exists. + Otherwise it raises an ``ValueError``. """ app_config = self.get_app_config(app_name=app_name) return app_config.get_table_with_name( table_class_name=table_class_name ) + + def get_table_classes( + self, + include_apps: Optional[list[str]] = None, + exclude_apps: Optional[list[str]] = None, + ) -> list[type[Table]]: + """ + Returns all ``Table`` classes registered with the given apps. If + ``include_apps`` is ``None``, then ``Table`` classes will be returned + for all apps. 
+        """
+        if include_apps and exclude_apps:
+            raise ValueError("Only specify `include_apps` or `exclude_apps`.")
+
+        if include_apps:
+            app_names = include_apps
+        else:
+            app_names = self.get_app_names()
+            if exclude_apps:
+                app_names = [i for i in app_names if i not in exclude_apps]
+
+        tables: list[type[Table]] = []
+
+        for app_name in app_names:
+            app_config = self.get_app_config(app_name=app_name)
+            tables.extend(app_config.table_classes)
+
+        return tables
+
+
+###############################################################################
+
+
+class PiccoloConfUpdater:
+
+    def __init__(self, piccolo_conf_path: Optional[str] = None):
+        """
+        :param piccolo_conf_path:
+            The path to the piccolo_conf.py (e.g. `./piccolo_conf.py`). If not
+            passed in, we use our ``Finder`` class to get it.
+        """
+        self.piccolo_conf_path = (
+            piccolo_conf_path or Finder().get_piccolo_conf_path()
+        )
+
+    def _modify_app_registry_src(self, src: str, app_module: str) -> str:
+        """
+        :param src:
+            The contents of the ``piccolo_conf.py`` file.
+        :param app_module:
+            The app to add to the registry e.g. ``'music.piccolo_app'``.
+        :returns:
+            Updated Python source code string.
+
+        """
+        ast_root = ast.parse(src)
+
+        parsing_successful = False
+
+        for node in ast.walk(ast_root):
+            if isinstance(node, ast.Call):
+                if (
+                    isinstance(node.func, ast.Name)
+                    and node.func.id == "AppRegistry"
+                ):
+                    if len(node.keywords) > 0:
+                        keyword = node.keywords[0]
+                        if keyword.arg == "apps":
+                            apps = keyword.value
+                            if isinstance(apps, ast.List):
+                                apps.elts.append(
+                                    ast.Constant(app_module, kind="str")
+                                )
+                                parsing_successful = True
+                                break
+
+        if not parsing_successful:
+            raise SyntaxError(
+                "Unable to parse piccolo_conf.py - `AppRegistry(apps=...)` "
+                "not found."
+            )
+
+        new_contents = ast.unparse(ast_root)
+
+        formatted_contents = black.format_str(
+            new_contents, mode=black.FileMode(line_length=80)
+        )
+
+        return formatted_contents
+
+    def register_app(self, app_module: str):
+        """
+        Adds the given app to the ``AppRegistry`` in ``piccolo_conf.py``.
+
+        This is used by command line tools like:
+
+        .. code-block:: bash
+
+            piccolo app new my_app --register
+
+        :param app_module:
+            The module of the app, e.g. ``'music.piccolo_app'``.
+
+        """
+        with open(self.piccolo_conf_path) as f:
+            piccolo_conf_src = f.read()
+
+        new_contents = self._modify_app_registry_src(
+            src=piccolo_conf_src, app_module=app_module
+        )
+
+        with open(self.piccolo_conf_path, "wt") as f:
+            f.write(new_contents)
diff --git a/piccolo/custom_types.py b/piccolo/custom_types.py
index 0efd1b3d3..49f078076 100644
--- a/piccolo/custom_types.py
+++ b/piccolo/custom_types.py
@@ -1,14 +1,41 @@
 from __future__ import annotations
 
-import typing as t
+import datetime
+import decimal
+import uuid
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar, Union
 
-if t.TYPE_CHECKING:  # pragma: no cover
+from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:  # pragma: no cover
     from piccolo.columns.combination import And, Or, Where, WhereRaw  # noqa
+    from piccolo.table import Table
+
+
+Combinable = Union["Where", "WhereRaw", "And", "Or"]
+CustomIterable = Iterable[Any]
+
+TableInstance = TypeVar("TableInstance", bound="Table")
+QueryResponseType = TypeVar("QueryResponseType", bound=Any)
 
-Combinable = t.Union["Where", "WhereRaw", "And", "Or"]
-Iterable = t.Iterable[t.Any]
+# These are types we can reasonably expect to send to the database.
+BasicTypes: TypeAlias = Union[ + bytes, + datetime.date, + datetime.datetime, + datetime.time, + datetime.timedelta, + decimal.Decimal, + dict, + float, + int, + list, + str, + uuid.UUID, +] ############################################################################### # For backwards compatibility: diff --git a/piccolo/engine/__init__.py b/piccolo/engine/__init__.py index 2cae3a4db..eb050f5e6 100644 --- a/piccolo/engine/__init__.py +++ b/piccolo/engine/__init__.py @@ -1,6 +1,13 @@ from .base import Engine +from .cockroach import CockroachEngine from .finder import engine_finder from .postgres import PostgresEngine from .sqlite import SQLiteEngine -__all__ = ["Engine", "PostgresEngine", "SQLiteEngine", "engine_finder"] +__all__ = [ + "Engine", + "PostgresEngine", + "SQLiteEngine", + "CockroachEngine", + "engine_finder", +] diff --git a/piccolo/engine/base.py b/piccolo/engine/base.py index 3f20ed0d8..b86af9dd2 100644 --- a/piccolo/engine/base.py +++ b/piccolo/engine/base.py @@ -1,49 +1,154 @@ from __future__ import annotations -import typing as t +import contextvars +import logging +import pprint +import string from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Final, Generic, Optional, TypeVar, Union + +from typing_extensions import Self from piccolo.querystring import QueryString from piccolo.utils.sync import run_sync -from piccolo.utils.warnings import Level, colored_warning +from piccolo.utils.warnings import Level, colored_string, colored_warning + +if TYPE_CHECKING: # pragma: no cover + from piccolo.query.base import DDL, Query + + +logger = logging.getLogger(__name__) +# This is a set to speed up lookups from O(n) when +# using str vs O(1) when using set[str] +VALID_SAVEPOINT_CHARACTERS: Final[set[str]] = set( + string.ascii_letters + string.digits + "-" + "_" +) + + +def validate_savepoint_name(savepoint_name: str) -> None: + """Validates a save point's name meets the required character set.""" + if not all(i in VALID_SAVEPOINT_CHARACTERS for i in savepoint_name): + raise ValueError( + "Savepoint names can only contain the following characters:" + f" {VALID_SAVEPOINT_CHARACTERS}" + ) + + +class BaseBatch(metaclass=ABCMeta): + @abstractmethod + async def __aenter__(self: Self, *args, **kwargs) -> Self: ... + + @abstractmethod + async def __aexit__(self, *args, **kwargs): ... + + @abstractmethod + def __aiter__(self: Self) -> Self: ... + + @abstractmethod + async def __anext__(self) -> list[dict]: ... + + +class BaseTransaction(metaclass=ABCMeta): + + __slots__: tuple[str, ...] = tuple() + + @abstractmethod + async def __aenter__(self, *args, **kwargs): ... + + @abstractmethod + async def __aexit__(self, *args, **kwargs) -> bool: ... + + +class BaseAtomic(metaclass=ABCMeta): + + __slots__: tuple[str, ...] = tuple() + + @abstractmethod + def add(self, *query: Union[Query, DDL]): ... + + @abstractmethod + async def run(self): ... + + @abstractmethod + def run_sync(self): ... + + @abstractmethod + def __await__(self): ... 
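As a brief aside - a minimal sketch of how the ``validate_savepoint_name`` helper added above behaves (the savepoint names here are illustrative, not part of this diff)::

    from piccolo.engine.base import validate_savepoint_name

    validate_savepoint_name("my_savepoint_1")  # fine - letters, digits, - and _

    # Anything outside VALID_SAVEPOINT_CHARACTERS (spaces, quotes, semicolons)
    # raises a ValueError. This matters because the name is later interpolated
    # into raw SAVEPOINT / ROLLBACK TO SAVEPOINT statements:
    validate_savepoint_name("bad name; --")  # raises ValueError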
-if t.TYPE_CHECKING: # pragma: no cover - from piccolo.query.base import Query +TransactionClass = TypeVar("TransactionClass", bound=BaseTransaction) -class Batch: - pass +class Engine(Generic[TransactionClass], metaclass=ABCMeta): + __slots__ = ( + "query_id", + "log_queries", + "log_responses", + "engine_type", + "min_version_number", + "current_transaction", + ) + + def __init__( + self, + engine_type: str, + min_version_number: Union[int, float], + log_queries: bool = False, + log_responses: bool = False, + ): + self.log_queries = log_queries + self.log_responses = log_responses + self.engine_type = engine_type + self.min_version_number = min_version_number -class Engine(metaclass=ABCMeta): - def __init__(self): run_sync(self.check_version()) run_sync(self.prep_database()) + self.query_id = 0 - @property @abstractmethod - def engine_type(self) -> str: + async def get_version(self) -> float: pass - @property @abstractmethod - def min_version_number(self) -> float: + def get_version_sync(self) -> float: pass @abstractmethod - async def get_version(self) -> float: + async def prep_database(self): pass @abstractmethod - async def prep_database(self): + async def batch( + self, + query: Query, + batch_size: int = 100, + node: Optional[str] = None, + ) -> BaseBatch: + pass + + @abstractmethod + async def run_querystring( + self, querystring: QueryString, in_pool: bool = True + ): + pass + + def transform_response_to_dicts(self, results) -> list[dict]: + """ + If the database adapter returns something other than a list of + dictionaries, it should perform the transformation here. + """ + return results + + @abstractmethod + async def run_ddl(self, ddl: str, in_pool: bool = True): pass @abstractmethod - async def batch(self, query: Query, batch_size: int = 100) -> Batch: + def transaction(self, *args, **kwargs) -> TransactionClass: pass @abstractmethod - async def run_querystring(self, querystring: QueryString, in_pool: bool): + def atomic(self) -> BaseAtomic: pass async def check_version(self): @@ -60,8 +165,8 @@ async def check_version(self): return engine_type = self.engine_type.capitalize() - print(f"Running {engine_type} version {version_number}") - if version_number < self.min_version_number: + logger.info(f"Running {engine_type} version {version_number}") + if version_number and (version_number < self.min_version_number): message = ( f"This version of {self.engine_type} isn't supported " f"(< {self.min_version_number}) - some features might not be " @@ -69,3 +174,56 @@ async def check_version(self): "Piccolo docs." ) colored_warning(message, stacklevel=3) + + def _connection_pool_warning(self): + message = ( + f"Connection pooling is not supported for {self.engine_type}." + ) + logger.warning(message) + colored_warning(message, stacklevel=3) + + async def start_connection_pool(self): + """ + The database driver doesn't implement connection pooling. + """ + self._connection_pool_warning() + + async def close_connection_pool(self): + """ + The database driver doesn't implement connection pooling. + """ + self._connection_pool_warning() + + ########################################################################### + + current_transaction: contextvars.ContextVar[Optional[TransactionClass]] + + def transaction_exists(self) -> bool: + """ + Find out if a transaction is currently active. + + :returns: + ``True`` if a transaction is already active for the current + asyncio task. 
This is useful to know, because nested transactions
+            aren't currently supported, so you can check if an existing
+            transaction is already active, before creating a new one.
+
+        """
+        return self.current_transaction.get() is not None
+
+    ###########################################################################
+    # Logging queries and responses
+
+    def get_query_id(self) -> int:
+        self.query_id += 1
+        return self.query_id
+
+    def print_query(self, query_id: int, query: str):
+        print(colored_string(f"\nQuery {query_id}:"))
+        print(query)
+
+    def print_response(self, query_id: int, response: list):
+        print(
+            colored_string(f"\nQuery {query_id} response:", level=Level.high)
+        )
+        pprint.pprint(response)
diff --git a/piccolo/engine/cockroach.py b/piccolo/engine/cockroach.py
new file mode 100644
index 000000000..e823ced5a
--- /dev/null
+++ b/piccolo/engine/cockroach.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any, Optional
+
+from piccolo.utils.lazy_loader import LazyLoader
+from piccolo.utils.warnings import Level, colored_warning
+
+from .postgres import PostgresEngine
+
+asyncpg = LazyLoader("asyncpg", globals(), "asyncpg")
+
+
+class CockroachEngine(PostgresEngine):
+    """
+    An extension of
+    :class:`PostgresEngine `.
+    """
+
+    def __init__(
+        self,
+        config: dict[str, Any],
+        extensions: Sequence[str] = (),
+        log_queries: bool = False,
+        log_responses: bool = False,
+        extra_nodes: Optional[dict[str, CockroachEngine]] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            extensions=extensions,
+            log_queries=log_queries,
+            log_responses=log_responses,
+            extra_nodes=extra_nodes,
+        )
+        self.engine_type = "cockroach"
+        self.min_version_number = 0
+
+    async def prep_database(self):
+        try:
+            await self._run_in_new_connection(
+                "SET CLUSTER SETTING sql.defaults.experimental_alter_column_type.enabled = true;"  # noqa: E501
+            )
+        except asyncpg.exceptions.InsufficientPrivilegeError:
+            colored_warning(
+                "=> Unable to set up Cockroach DB - some "
+                "functionality may not behave as expected. Make sure "
+                "your database user has permission to set cluster options.",
+                level=Level.medium,
+            )
diff --git a/piccolo/engine/finder.py b/piccolo/engine/finder.py
index e67accc93..a72db979a 100644
--- a/piccolo/engine/finder.py
+++ b/piccolo/engine/finder.py
@@ -1,11 +1,11 @@
 from __future__ import annotations

-import typing as t
+from typing import Optional

 from piccolo.engine.base import Engine


-def engine_finder(module_name: t.Optional[str] = None) -> t.Optional[Engine]:
+def engine_finder(module_name: Optional[str] = None) -> Optional[Engine]:
     """
     An example module name is `my_piccolo_conf`.
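For context, a hedged sketch of how the new ``CockroachEngine`` might be wired up in a ``piccolo_conf.py`` - the host, port, user and database values are placeholders, not part of this diff::

    # piccolo_conf.py (illustrative only - connection details are placeholders)
    from piccolo.engine.cockroach import CockroachEngine

    DB = CockroachEngine(
        config={
            "host": "localhost",
            "port": 26257,  # CockroachDB's default SQL port
            "user": "piccolo",
            "database": "piccolo",
        },
        log_queries=True,  # print each statement before it runs
    )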
diff --git a/piccolo/engine/postgres.py b/piccolo/engine/postgres.py
index 7d815bd0a..f4ac2d17e 100644
--- a/piccolo/engine/postgres.py
+++ b/piccolo/engine/postgres.py
@@ -1,35 +1,44 @@
 from __future__ import annotations

 import contextvars
-import typing as t
+from collections.abc import Sequence
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union

-from piccolo.engine.base import Batch, Engine
+from typing_extensions import Self
+
+from piccolo.engine.base import (
+    BaseAtomic,
+    BaseBatch,
+    BaseTransaction,
+    Engine,
+    validate_savepoint_name,
+)
 from piccolo.engine.exceptions import TransactionError
-from piccolo.query.base import Query
+from piccolo.query.base import DDL, Query
 from piccolo.querystring import QueryString
 from piccolo.utils.lazy_loader import LazyLoader
 from piccolo.utils.sync import run_sync
-from piccolo.utils.warnings import Level, colored_string, colored_warning
+from piccolo.utils.warnings import Level, colored_warning

 asyncpg = LazyLoader("asyncpg", globals(), "asyncpg")

-if t.TYPE_CHECKING:  # pragma: no cover
-    from asyncpg.connection import Connection  # type: ignore
-    from asyncpg.cursor import Cursor  # type: ignore
-    from asyncpg.pool import Pool  # type: ignore
+if TYPE_CHECKING:  # pragma: no cover
+    from asyncpg.connection import Connection
+    from asyncpg.cursor import Cursor
+    from asyncpg.pool import Pool
+    from asyncpg.transaction import Transaction


 @dataclass
-class AsyncBatch(Batch):
-
+class AsyncBatch(BaseBatch):
     connection: Connection
     query: Query
     batch_size: int

     # Set internally
-    _transaction = None
-    _cursor: t.Optional[Cursor] = None
+    _transaction: Optional[Transaction] = None
+    _cursor: Optional[Cursor] = None

     @property
     def cursor(self) -> Cursor:
@@ -37,20 +46,26 @@ def cursor(self) -> Cursor:
             raise ValueError("_cursor not set")
         return self._cursor

-    async def next(self) -> t.List[t.Dict]:
+    @property
+    def transaction(self) -> Transaction:
+        if not self._transaction:
+            raise ValueError("The transaction can't be found.")
+        return self._transaction
+
+    async def next(self) -> list[dict]:
         data = await self.cursor.fetch(self.batch_size)
         return await self.query._process_results(data)

-    def __aiter__(self):
+    def __aiter__(self: Self) -> Self:
         return self

-    async def __anext__(self):
+    async def __anext__(self) -> list[dict]:
         response = await self.next()
         if response == []:
             raise StopAsyncIteration()
         return response

-    async def __aenter__(self):
+    async def __aenter__(self: Self) -> Self:
         self._transaction = self.connection.transaction()
         await self._transaction.start()
         querystring = self.query.querystrings[0]
@@ -61,9 +76,9 @@ async def __aenter__(self):

     async def __aexit__(self, exception_type, exception, traceback):
         if exception:
-            await self._transaction.rollback()
+            await self.transaction.rollback()
         else:
-            await self._transaction.commit()
+            await self.transaction.commit()

         await self.connection.close()

@@ -73,120 +88,196 @@ async def __aexit__(self, exception_type, exception, traceback):
 ###############################################################################


-class Atomic:
+class Atomic(BaseAtomic):
     """
     This is useful if you want to build up a transaction programmatically, by
     adding queries to it.
- Usage: + Usage:: - transaction = engine.atomic() - transaction.add(Foo.create_table()) + transaction = engine.atomic() + transaction.add(Foo.create_table()) + + # Either: + transaction.run_sync() + await transaction.run() - # Either: - transaction.run_sync() - await transaction.run() """ __slots__ = ("engine", "queries") def __init__(self, engine: PostgresEngine): self.engine = engine - self.queries: t.List[Query] = [] + self.queries: list[Union[Query, DDL]] = [] - def add(self, *query: Query): + def add(self, *query: Union[Query, DDL]): self.queries += list(query) - async def _run_queries(self, connection): - async with connection.transaction(): - for query in self.queries: - for querystring in query.querystrings: - _query, args = querystring.compile_string( - engine_type=self.engine.engine_type - ) - await connection.execute(_query, *args) - - self.queries = [] - - async def _run_in_pool(self): - pool = await self.engine.get_pool() - connection = await pool.acquire() + async def run(self): + from piccolo.query.methods.objects import Create, GetOrCreate try: - await self._run_queries(connection) - except Exception: - pass - finally: - await pool.release(connection) - - async def _run_in_new_connection(self): - connection = await asyncpg.connect(**self.engine.config) - await self._run_queries(connection) - - async def run(self, in_pool=True): - if in_pool and self.engine.pool: - await self._run_in_pool() - else: - await self._run_in_new_connection() + async with self.engine.transaction(): + for query in self.queries: + if isinstance(query, (Query, DDL, Create, GetOrCreate)): + await query.run() + else: + raise ValueError("Unrecognised query") + self.queries = [] + except Exception as exception: + self.queries = [] + raise exception from exception def run_sync(self): - return run_sync(self._run_in_new_connection()) + return run_sync(self.run()) + + def __await__(self): + return self.run().__await__() ############################################################################### -class Transaction: +class Savepoint: + def __init__(self, name: str, transaction: PostgresTransaction): + self.name = name + self.transaction = transaction + + async def rollback_to(self): + validate_savepoint_name(self.name) + await self.transaction.connection.execute( + f"ROLLBACK TO SAVEPOINT {self.name}" + ) + + async def release(self): + validate_savepoint_name(self.name) + await self.transaction.connection.execute( + f"RELEASE SAVEPOINT {self.name}" + ) + + +class PostgresTransaction(BaseTransaction): """ Used for wrapping queries in a transaction, using a context manager. Currently it's async only. - Usage: + Usage:: - async with engine.transaction(): - # Run some queries: - await Band.select().run() + async with engine.transaction(): + # Run some queries: + await Band.select().run() """ - __slots__ = ("engine", "transaction", "context", "connection") + __slots__ = ( + "engine", + "transaction", + "context", + "connection", + "_savepoint_id", + "_parent", + "_committed", + "_rolled_back", + ) + + def __init__(self, engine: PostgresEngine, allow_nested: bool = True): + """ + :param allow_nested: + If ``True`` then if we try creating a new transaction when another + is already active, we treat this as a no-op:: - def __init__(self, engine: PostgresEngine): + async with DB.transaction(): + async with DB.transaction(): + pass + + If we want to disallow this behaviour, then setting + ``allow_nested=False`` will cause a ``TransactionError`` to be + raised. 
+ + """ self.engine = engine - if self.engine.transaction_connection.get(): - raise TransactionError( - "A transaction is already active - nested transactions aren't " - "currently supported." - ) + current_transaction = self.engine.current_transaction.get() + + self._savepoint_id = 0 + self._parent = None + self._committed = False + self._rolled_back = False + + if current_transaction: + if allow_nested: + self._parent = current_transaction + else: + raise TransactionError( + "A transaction is already active - nested transactions " + "aren't allowed." + ) + + async def __aenter__(self) -> PostgresTransaction: + if self._parent is not None: + return self._parent - async def __aenter__(self): + self.connection = await self.get_connection() + self.transaction = self.connection.transaction() + await self.begin() + self.context = self.engine.current_transaction.set(self) + return self + + async def get_connection(self): if self.engine.pool: - self.connection = await self.engine.pool.acquire() + return await self.engine.pool.acquire() else: - self.connection = await self.engine.get_new_connection() + return await self.engine.get_new_connection() - self.transaction = self.connection.transaction() + async def begin(self): await self.transaction.start() - self.context = self.engine.transaction_connection.set(self.connection) async def commit(self): await self.transaction.commit() + self._committed = True async def rollback(self): await self.transaction.rollback() + self._rolled_back = True + + async def rollback_to(self, savepoint_name: str): + """ + Used to rollback to a savepoint just using the name. + """ + await Savepoint(name=savepoint_name, transaction=self).rollback_to() + + ########################################################################### + + def get_savepoint_id(self) -> int: + self._savepoint_id += 1 + return self._savepoint_id + + async def savepoint(self, name: Optional[str] = None) -> Savepoint: + name = name or f"savepoint_{self.get_savepoint_id()}" + validate_savepoint_name(name) + await self.connection.execute(f"SAVEPOINT {name}") + return Savepoint(name=name, transaction=self) + + ########################################################################### + + async def __aexit__(self, exception_type, exception, traceback) -> bool: + if self._parent: + return exception is None - async def __aexit__(self, exception_type, exception, traceback): if exception: - await self.rollback() + # The user may have manually rolled it back. + if not self._rolled_back: + await self.rollback() else: - await self.commit() + # The user may have manually committed it. + if not self._committed and not self._rolled_back: + await self.commit() if self.engine.pool: await self.engine.pool.release(self.connection) else: await self.connection.close() - self.engine.transaction_connection.reset(self.context) + self.engine.current_transaction.reset(self.context) return exception is None @@ -194,56 +285,99 @@ async def __aexit__(self, exception_type, exception, traceback): ############################################################################### -class PostgresEngine(Engine): +class PostgresEngine(Engine[PostgresTransaction]): """ - Used to connect to Postgresql. + Used to connect to PostgreSQL. :param config: The config dictionary is passed to the underlying database adapter, asyncpg. Common arguments you're likely to need are: - * host - * port - * user - * password - * database + * host + * port + * user + * password + * database For example, ``{'host': 'localhost', 'port': 5432}``. 
-    To see all available options:
-
-    * https://magicstack.github.io/asyncpg/current/api/index.html#connection
+    See the `asyncpg docs `_
+    for all available options.

     :param extensions:
         When the engine starts, it will try and create these extensions
-        in Postgres.
+        in Postgres. If you're using a read only database, set this value to an
+        empty tuple ``()``.

     :param log_queries:
-        If True, all SQL and DDL statements are printed out before being run.
-        Useful for debugging.
+        If ``True``, all SQL and DDL statements are printed out before being
+        run. Useful for debugging.
+
+    :param log_responses:
+        If ``True``, the raw response from each query is printed out. Useful
+        for debugging.
+
+    :param extra_nodes:
+        If you have additional database nodes (e.g. read replicas) for the
+        server, you can specify them here. It's a mapping of a memorable name
+        to a ``PostgresEngine`` instance. For example::
+
+            DB = PostgresEngine(
+                config={'database': 'main_db'},
+                extra_nodes={
+                    'read_replica_1': PostgresEngine(
+                        config={
+                            'database': 'main_db',
+                            'host': 'read_replica.my_db.com'
+                        },
+                        extensions=()
+                    )
+                }
+            )

-    """  # noqa: E501
+        Note how we set ``extensions=()``, because it's a read only database.

-    __slots__ = ("config", "extensions", "pool", "transaction_connection")
+        When executing a query, you can specify one of these nodes instead
+        of the main database. For example::

-    engine_type = "postgres"
-    min_version_number = 9.6
+            >>> await MyTable.select().run(node="read_replica_1")
+
+    """  # noqa: E501
+
+    __slots__ = (
+        "config",
+        "extensions",
+        "extra_nodes",
+        "pool",
+    )

     def __init__(
         self,
-        config: t.Dict[str, t.Any],
-        extensions: t.Sequence[str] = ["uuid-ossp"],
+        config: dict[str, Any],
+        extensions: Sequence[str] = tuple(),
         log_queries: bool = False,
+        log_responses: bool = False,
+        extra_nodes: Optional[Mapping[str, PostgresEngine]] = None,
     ) -> None:
+        if extra_nodes is None:
+            extra_nodes = {}
+
         self.config = config
         self.extensions = extensions
         self.log_queries = log_queries
-        self.pool: t.Optional[Pool] = None
+        self.log_responses = log_responses
+        self.extra_nodes = extra_nodes
+        self.pool: Optional[Pool] = None
         database_name = config.get("database", "Unknown")
-        self.transaction_connection = contextvars.ContextVar(
-            f"pg_transaction_connection_{database_name}", default=None
+        self.current_transaction = contextvars.ContextVar(
+            f"pg_current_transaction_{database_name}", default=None
+        )
+        super().__init__(
+            engine_type="postgres",
+            log_queries=log_queries,
+            log_responses=log_responses,
+            min_version_number=10,
         )
-        super().__init__()

     @staticmethod
     def _parse_raw_version_string(version_string: str) -> float:
@@ -256,22 +390,20 @@ def _parse_raw_version_string(version_string: str) -> float:
         """
         version_segment = version_string.split(" ")[0]
         major, minor = version_segment.split(".")[:2]
-        version = float(f"{major}.{minor}")
-        return version
+        return float(f"{major}.{minor}")

     async def get_version(self) -> float:
         """
         Returns the version of Postgres being run.
         """
         try:
-            response: t.Sequence[t.Dict] = await self._run_in_new_connection(
+            response: Sequence[dict] = await self._run_in_new_connection(
                 "SHOW server_version"
             )
         except ConnectionRefusedError as exception:
             # Suppressing the exception, otherwise importing piccolo_conf.py
             # containing an engine will raise an ImportError.
- colored_warning("Unable to connect to database") - print(exception) + colored_warning(f"Unable to connect to database - {exception}") return 0.0 else: version_string = response[0]["server_version"] @@ -279,6 +411,9 @@ async def get_version(self) -> float: version_string=version_string ) + def get_version_sync(self) -> float: + return run_sync(self.get_version()) + async def prep_database(self): for extension in self.extensions: try: @@ -286,20 +421,18 @@ async def prep_database(self): f'CREATE EXTENSION IF NOT EXISTS "{extension}"', ) except asyncpg.exceptions.InsufficientPrivilegeError: - print( - colored_string( - f"=> Unable to create {extension} extension - some " - "functionality may not behave as expected. Make sure " - "your database user has permission to create " - "extensions, or add it manually using " - f'`CREATE EXTENSION "{extension}";`', - level=Level.medium, - ) + colored_warning( + f"=> Unable to create {extension} extension - some " + "functionality may not behave as expected. Make sure " + "your database user has permission to create " + "extensions, or add it manually using " + f'`CREATE EXTENSION "{extension}";`', + level=Level.medium, ) ########################################################################### # These typos existed in the codebase for a while, so leaving these proxy - # methods for now to ensure backwards compatility. + # methods for now to ensure backwards compatibility. async def start_connnection_pool(self, **kwargs) -> None: colored_warning( @@ -347,15 +480,34 @@ async def get_new_connection(self) -> Connection: ########################################################################### - async def batch(self, query: Query, batch_size: int = 100) -> AsyncBatch: - connection = await self.get_new_connection() + async def batch( + self, + query: Query, + batch_size: int = 100, + node: Optional[str] = None, + ) -> AsyncBatch: + """ + :param query: + The database query to run. + :param batch_size: + How many rows to fetch on each iteration. + :param node: + Which node to run the query on (see ``extra_nodes``). If not + specified, it runs on the main Postgres node. 
+ """ + engine: Any = self.extra_nodes.get(node) if node else self + connection = await engine.get_new_connection() return AsyncBatch( connection=connection, query=query, batch_size=batch_size ) ########################################################################### - async def _run_in_pool(self, query: str, args: t.Sequence[t.Any] = []): + async def _run_in_pool( + self, query: str, args: Optional[Sequence[Any]] = None + ): + if args is None: + args = [] if not self.pool: raise ValueError("A pool isn't currently running.") @@ -365,10 +517,18 @@ async def _run_in_pool(self, query: str, args: t.Sequence[t.Any] = []): return response async def _run_in_new_connection( - self, query: str, args: t.Sequence[t.Any] = [] + self, query: str, args: Optional[Sequence[Any]] = None ): + if args is None: + args = [] connection = await self.get_new_connection() - results = await connection.fetch(query, *args) + + try: + results = await connection.fetch(query, *args) + except asyncpg.exceptions.PostgresError as exception: + await connection.close() + raise exception + await connection.close() return results @@ -379,20 +539,56 @@ async def run_querystring( engine_type=self.engine_type ) + query_id = self.get_query_id() + if self.log_queries: - print(querystring) + self.print_query(query_id=query_id, query=querystring.__str__()) # If running inside a transaction: - connection = self.transaction_connection.get() - if connection: - return await connection.fetch(query, *query_args) + current_transaction = self.current_transaction.get() + if current_transaction: + response = await current_transaction.connection.fetch( + query, *query_args + ) elif in_pool and self.pool: - return await self._run_in_pool(query, query_args) + response = await self._run_in_pool(query, query_args) else: - return await self._run_in_new_connection(query, query_args) + response = await self._run_in_new_connection(query, query_args) + + if self.log_responses: + self.print_response(query_id=query_id, response=response) + + return response + + async def run_ddl(self, ddl: str, in_pool: bool = True): + query_id = self.get_query_id() + + if self.log_queries: + self.print_query(query_id=query_id, query=ddl) + + # If running inside a transaction: + current_transaction = self.current_transaction.get() + if current_transaction: + response = await current_transaction.connection.fetch(ddl) + elif in_pool and self.pool: + response = await self._run_in_pool(ddl) + else: + response = await self._run_in_new_connection(ddl) + + if self.log_responses: + self.print_response(query_id=query_id, response=response) + + return response + + def transform_response_to_dicts(self, results) -> list[dict]: + """ + asyncpg returns a special Record object, so we need to convert it to + a dict. 
+ """ + return [dict(i) for i in results] def atomic(self) -> Atomic: return Atomic(engine=self) - def transaction(self) -> Transaction: - return Transaction(engine=self) + def transaction(self, allow_nested: bool = True) -> PostgresTransaction: + return PostgresTransaction(engine=self, allow_nested=allow_nested) diff --git a/piccolo/engine/sqlite.py b/piccolo/engine/sqlite.py index 8b70e825f..86de912f3 100644 --- a/piccolo/engine/sqlite.py +++ b/piccolo/engine/sqlite.py @@ -2,16 +2,27 @@ import contextvars import datetime +import enum import os import sqlite3 -import typing as t import uuid +from collections.abc import Callable from dataclasses import dataclass from decimal import Decimal - -from piccolo.engine.base import Batch, Engine +from functools import partial, wraps +from typing import TYPE_CHECKING, Any, Optional, Union + +from typing_extensions import Self + +from piccolo.engine.base import ( + BaseAtomic, + BaseBatch, + BaseTransaction, + Engine, + validate_savepoint_name, +) from piccolo.engine.exceptions import TransactionError -from piccolo.query.base import Query +from piccolo.query.base import DDL, Query from piccolo.querystring import QueryString from piccolo.utils.encoding import dump_json, load_json from piccolo.utils.lazy_loader import LazyLoader @@ -20,7 +31,7 @@ aiosqlite = LazyLoader("aiosqlite", globals(), "aiosqlite") -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from aiosqlite import Connection, Cursor # type: ignore from piccolo.table import Table @@ -34,14 +45,14 @@ # In -def convert_numeric_in(value): +def convert_numeric_in(value: Decimal) -> float: """ Convert any Decimal values into floats. """ return float(value) -def convert_uuid_in(value) -> str: +def convert_uuid_in(value: uuid.UUID) -> str: """ Converts the UUID value being passed into sqlite. """ @@ -55,7 +66,7 @@ def convert_time_in(value: datetime.time) -> str: return value.isoformat() -def convert_date_in(value: datetime.date): +def convert_date_in(value: datetime.date) -> str: """ Converts the date value being passed into sqlite. """ @@ -73,120 +84,259 @@ def convert_datetime_in(value: datetime.datetime) -> str: return str(value) -def convert_timedelta_in(value: datetime.timedelta): +def convert_timedelta_in(value: datetime.timedelta) -> float: """ Converts the timedelta value being passed into sqlite. """ return value.total_seconds() -def convert_array_in(value: list): +def convert_array_in(value: list) -> str: """ - Converts a list value into a string. + Converts a list value into a string (it handles nested lists, and type like + dateime/ time / date which aren't usually JSON serialisable.). + """ - if len(value) > 0: - if type(value[0]) not in [str, int, float]: - raise ValueError("Can only serialise str, int and float.") - return dump_json(value) + def serialise(data: list): + output = [] + + for item in data: + if isinstance(item, list): + output.append(serialise(item)) + elif isinstance( + item, (datetime.datetime, datetime.time, datetime.date) + ): + if adapter := ADAPTERS.get(type(item)): + output.append(adapter(item)) + else: + raise ValueError("The adapter wasn't found.") + elif item is None or isinstance(item, (str, int, float, list)): + # We can safely JSON serialise these. 
+ output.append(item) + else: + raise ValueError("We can't currently serialise this value.") + return output + + return dump_json(serialise(value)) + + +############################################################################### + +# Register adapters + +ADAPTERS: dict[type, Callable[[Any], Any]] = { + Decimal: convert_numeric_in, + uuid.UUID: convert_uuid_in, + datetime.time: convert_time_in, + datetime.date: convert_date_in, + datetime.datetime: convert_datetime_in, + datetime.timedelta: convert_timedelta_in, + list: convert_array_in, +} + +for value_type, adapter in ADAPTERS.items(): + sqlite3.register_adapter(value_type, adapter) + +############################################################################### # Out -def convert_numeric_out(value: bytes) -> Decimal: +def decode_to_string(converter: Callable[[str], Any]): + """ + This means we can use our converters with string and bytes. They are + passed bytes when used directly via SQLite, and are passed strings when + used by the array converters. + """ + + @wraps(converter) + def wrapper(value: Union[str, bytes]) -> Any: + if isinstance(value, bytes): + return converter(value.decode("utf8")) + elif isinstance(value, str): + return converter(value) + else: + raise ValueError("Unsupported type") + + return wrapper + + +@decode_to_string +def convert_numeric_out(value: str) -> Decimal: """ Convert float values into Decimals. """ - return Decimal(value.decode("ascii")) + return Decimal(value) -def convert_int_out(value: bytes) -> int: +@decode_to_string +def convert_int_out(value: str) -> int: """ - Make sure Integer values are actually of type int. + Make sure INTEGER values are actually of type ``int``. + + SQLite doesn't enforce that the values in INTEGER columns are actually + integers - they could be strings ('hello'), or floats (1.0). + + There's not much we can do if the value is something like 'hello' - a + ``ValueError`` is appropriate in this situation. + + For a value like ``1.0``, it seems reasonable to handle this, and return a + value of ``1``. + """ - return int(float(value)) + # We used to use int(float(value)), but it was incorrect, because float has + # limited precision for large numbers. + return int(Decimal(value)) -def convert_uuid_out(value: bytes) -> uuid.UUID: +@decode_to_string +def convert_uuid_out(value: str) -> uuid.UUID: """ If the value is a uuid, convert it to a UUID instance. """ - return uuid.UUID(value.decode("utf8")) + return uuid.UUID(value) -def convert_date_out(value: bytes) -> datetime.date: - return datetime.date.fromisoformat(value.decode("utf8")) +@decode_to_string +def convert_date_out(value: str) -> datetime.date: + return datetime.date.fromisoformat(value) -def convert_time_out(value: bytes) -> datetime.time: +@decode_to_string +def convert_time_out(value: str) -> datetime.time: """ If the value is a time, convert it to a UUID instance. """ - return datetime.time.fromisoformat(value.decode("utf8")) + return datetime.time.fromisoformat(value) -def convert_seconds_out(value: bytes) -> datetime.timedelta: +@decode_to_string +def convert_seconds_out(value: str) -> datetime.timedelta: """ If the value is from a seconds column, convert it to a timedelta instance. """ - return datetime.timedelta(seconds=float(value.decode("utf8"))) + return datetime.timedelta(seconds=float(value)) -def convert_boolean_out(value: bytes) -> bool: +@decode_to_string +def convert_boolean_out(value: str) -> bool: """ If the value is from a boolean column, convert it to a bool value. 
""" - _value = value.decode("utf8") - return _value == "1" + return value == "1" -def convert_timestamptz_out(value: bytes) -> datetime.datetime: +@decode_to_string +def convert_timestamp_out(value: str) -> datetime.datetime: """ - If the value is from a timstamptz column, convert it to a datetime value, + If the value is from a timestamp column, convert it to a datetime value. + """ + return datetime.datetime.fromisoformat(value) + + +@decode_to_string +def convert_timestamptz_out(value: str) -> datetime.datetime: + """ + If the value is from a timestamptz column, convert it to a datetime value, with a timezone of UTC. """ - return datetime.datetime.fromisoformat(value.decode("utf8")) + return datetime.datetime.fromisoformat(value).replace( + tzinfo=datetime.timezone.utc + ) -def convert_array_out(value: bytes) -> t.List: +@decode_to_string +def convert_array_out(value: str) -> list: """ If the value if from an array column, deserialise the string back into a list. """ - return load_json(value.decode("utf8")) + return load_json(value) -sqlite3.register_converter("Numeric", convert_numeric_out) -sqlite3.register_converter("Integer", convert_int_out) -sqlite3.register_converter("UUID", convert_uuid_out) -sqlite3.register_converter("Date", convert_date_out) -sqlite3.register_converter("Time", convert_time_out) -sqlite3.register_converter("Seconds", convert_seconds_out) -sqlite3.register_converter("Boolean", convert_boolean_out) -sqlite3.register_converter("Timestamptz", convert_timestamptz_out) -sqlite3.register_converter("Array", convert_array_out) +def convert_complex_array_out(value: bytes, converter: Callable): + """ + This is used to handle arrays of things like timestamps, which we can't + just load from JSON without doing additional work to convert the elements + back into Python objects. + """ + parsed = load_json(value.decode("utf8")) + + def convert_list(list_value: list): + output = [] + + for value in list_value: + if isinstance(value, list): + # For nested arrays + output.append(convert_list(value)) + elif isinstance(value, str): + output.append(converter(value)) + else: + output.append(value) + + return output + + if isinstance(parsed, list): + return convert_list(parsed) + else: + return parsed + + +@decode_to_string +def convert_M2M_out(value: str) -> list: + return value.split(",") -sqlite3.register_adapter(Decimal, convert_numeric_in) -sqlite3.register_adapter(uuid.UUID, convert_uuid_in) -sqlite3.register_adapter(datetime.time, convert_time_in) -sqlite3.register_adapter(datetime.date, convert_date_in) -sqlite3.register_adapter(datetime.datetime, convert_datetime_in) -sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in) -sqlite3.register_adapter(list, convert_array_in) ############################################################################### +# Register the basic converters + +CONVERTERS = { + "NUMERIC": convert_numeric_out, + "INTEGER": convert_int_out, + "UUID": convert_uuid_out, + "DATE": convert_date_out, + "TIME": convert_time_out, + "SECONDS": convert_seconds_out, + "BOOLEAN": convert_boolean_out, + "TIMESTAMP": convert_timestamp_out, + "TIMESTAMPTZ": convert_timestamptz_out, + "M2M": convert_M2M_out, +} + +for column_name, converter in CONVERTERS.items(): + sqlite3.register_converter(column_name, converter) +############################################################################### +# Register the array converters + +# The ARRAY column type handles values which can be easily serialised to and +# from JSON. 
+sqlite3.register_converter("ARRAY", convert_array_out) + +# We have special column types for arrays of timestamps etc, as simply loading +# the JSON isn't sufficient. +for column_name in ("TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME"): + sqlite3.register_converter( + f"ARRAY_{column_name}", + partial( + convert_complex_array_out, + converter=CONVERTERS[column_name], + ), + ) -@dataclass -class AsyncBatch(Batch): +############################################################################### + +@dataclass +class AsyncBatch(BaseBatch): connection: Connection query: Query batch_size: int # Set internally - _cursor: t.Optional[Cursor] = None + _cursor: Optional[Cursor] = None @property def cursor(self) -> Cursor: @@ -194,20 +344,20 @@ def cursor(self) -> Cursor: raise ValueError("_cursor not set") return self._cursor - async def next(self) -> t.List[t.Dict]: + async def next(self) -> list[dict]: data = await self.cursor.fetchmany(self.batch_size) return await self.query._process_results(data) - def __aiter__(self): + def __aiter__(self: Self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> list[dict]: response = await self.next() if response == []: raise StopAsyncIteration() return response - async def __aenter__(self): + async def __aenter__(self: Self) -> Self: querystring = self.query.querystrings[0] template, template_args = querystring.compile_string() @@ -215,7 +365,7 @@ async def __aenter__(self): return self async def __aexit__(self, exception_type, exception, traceback): - await self._cursor.close() + await self.cursor.close() await self.connection.close() return exception is not None @@ -223,7 +373,18 @@ async def __aexit__(self, exception_type, exception, traceback): ############################################################################### -class Atomic: +class TransactionType(enum.Enum): + """ + See the `SQLite `_ docs for + more info. 
+ """ + + deferred = "DEFERRED" + immediate = "IMMEDIATE" + exclusive = "EXCLUSIVE" + + +class Atomic(BaseAtomic): """ Usage: @@ -235,80 +396,180 @@ class Atomic: await transaction.run() """ - __slots__ = ("engine", "queries") + __slots__ = ("engine", "queries", "transaction_type") - def __init__(self, engine: SQLiteEngine): + def __init__( + self, + engine: SQLiteEngine, + transaction_type: TransactionType = TransactionType.deferred, + ): self.engine = engine - self.queries: t.List[Query] = [] + self.transaction_type = transaction_type + self.queries: list[Union[Query, DDL]] = [] - def add(self, *query: Query): + def add(self, *query: Union[Query, DDL]): self.queries += list(query) async def run(self): - connection = await self.engine.get_connection() - await connection.execute("BEGIN") + from piccolo.query.methods.objects import Create, GetOrCreate try: - for query in self.queries: - for querystring in query.querystrings: - await connection.execute( - *querystring.compile_string( - engine_type=self.engine.engine_type - ) - ) - except Exception as exception: - await connection.execute("ROLLBACK") - await connection.close() + async with self.engine.transaction( + transaction_type=self.transaction_type + ): + for query in self.queries: + if isinstance(query, (Query, DDL, Create, GetOrCreate)): + await query.run() + else: + raise ValueError("Unrecognised query") self.queries = [] - raise exception - else: - await connection.execute("COMMIT") - await connection.close() + except Exception as exception: self.queries = [] + raise exception from exception def run_sync(self): return run_sync(self.run()) + def __await__(self): + return self.run().__await__() + ############################################################################### -class Transaction: +class Savepoint: + def __init__(self, name: str, transaction: SQLiteTransaction): + self.name = name + self.transaction = transaction + + async def rollback_to(self): + validate_savepoint_name(self.name) + await self.transaction.connection.execute( + f"ROLLBACK TO SAVEPOINT {self.name}" + ) + + async def release(self): + validate_savepoint_name(self.name) + await self.transaction.connection.execute( + f"RELEASE SAVEPOINT {self.name}" + ) + + +class SQLiteTransaction(BaseTransaction): """ Used for wrapping queries in a transaction, using a context manager. Currently it's async only. - Usage: + Usage:: - async with engine.transaction(): - # Run some queries: - await Band.select().run() + async with engine.transaction(): + # Run some queries: + await Band.select().run() """ - __slots__ = ("engine", "context", "connection") + __slots__ = ( + "engine", + "context", + "connection", + "transaction_type", + "allow_nested", + "_savepoint_id", + "_parent", + "_committed", + "_rolled_back", + ) - def __init__(self, engine: SQLiteEngine): + def __init__( + self, + engine: SQLiteEngine, + transaction_type: TransactionType = TransactionType.deferred, + allow_nested: bool = True, + ): + """ + :param transaction_type: + If your transaction just contains ``SELECT`` queries, then use + ``TransactionType.deferred``. This will give you the best + performance. When performing ``INSERT``, ``UPDATE``, ``DELETE`` + queries, we recommend using ``TransactionType.immediate`` to + avoid database locks. + """ self.engine = engine - if self.engine.transaction_connection.get(): - raise TransactionError( - "A transaction is already active - nested transactions aren't " - "currently supported." 
- ) + self.transaction_type = transaction_type + current_transaction = self.engine.current_transaction.get() - async def __aenter__(self): - self.connection = await self.engine.get_connection() - await self.connection.execute("BEGIN") - self.context = self.engine.transaction_connection.set(self.connection) + self._savepoint_id = 0 + self._parent = None + self._committed = False + self._rolled_back = False + + if current_transaction: + if allow_nested: + self._parent = current_transaction + else: + raise TransactionError( + "A transaction is already active - nested transactions " + "aren't allowed." + ) + + async def __aenter__(self) -> SQLiteTransaction: + if self._parent is not None: + return self._parent + + self.connection = await self.get_connection() + await self.begin() + self.context = self.engine.current_transaction.set(self) + return self + + async def get_connection(self): + return await self.engine.get_connection() + + async def begin(self): + await self.connection.execute(f"BEGIN {self.transaction_type.value}") + + async def commit(self): + await self.connection.execute("COMMIT") + self._committed = True + + async def rollback(self): + await self.connection.execute("ROLLBACK") + self._rolled_back = True + + async def rollback_to(self, savepoint_name: str): + """ + Used to rollback to a savepoint just using the name. + """ + await Savepoint(name=savepoint_name, transaction=self).rollback_to() + + ########################################################################### + + def get_savepoint_id(self) -> int: + self._savepoint_id += 1 + return self._savepoint_id + + async def savepoint(self, name: Optional[str] = None) -> Savepoint: + name = name or f"savepoint_{self.get_savepoint_id()}" + validate_savepoint_name(name) + await self.connection.execute(f"SAVEPOINT {name}") + return Savepoint(name=name, transaction=self) + + ########################################################################### + + async def __aexit__(self, exception_type, exception, traceback) -> bool: + if self._parent: + return exception is None - async def __aexit__(self, exception_type, exception, traceback): if exception: - await self.connection.execute("ROLLBACK") + # The user may have manually rolled it back. + if not self._rolled_back: + await self.rollback() else: - await self.connection.execute("COMMIT") + # The user may have manually committed it. + if not self._committed and not self._rolled_back: + await self.commit() await self.connection.close() - self.engine.transaction_connection.reset(self.context) + self.engine.current_transaction.reset(self.context) return exception is None @@ -316,49 +577,62 @@ async def __aexit__(self, exception_type, exception, traceback): ############################################################################### -def dict_factory(cursor, row) -> t.Dict: - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - -class SQLiteEngine(Engine): - """ - Any connection kwargs are passed into the database adapter. 
-
-    See here for more info:
-    https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+def dict_factory(cursor, row) -> dict:
+    return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}

-    """

+class SQLiteEngine(Engine[SQLiteTransaction]):
     __slots__ = ("connection_kwargs",)

-    engine_type = "sqlite"
-    min_version_number = 3.25
-
     def __init__(
         self,
         path: str = "piccolo.sqlite",
-        detect_types=sqlite3.PARSE_DECLTYPES,
-        isolation_level=None,
+        log_queries: bool = False,
+        log_responses: bool = False,
         **connection_kwargs,
     ) -> None:
-        connection_kwargs.update(
-            {
-                "database": path,
-                "detect_types": detect_types,
-                "isolation_level": isolation_level,
-            }
+        """
+        :param path:
+            A relative or absolute path to the SQLite database file (it
+            will be created if it doesn't already exist).
+        :param log_queries:
+            If ``True``, all SQL and DDL statements are printed out before
+            being run. Useful for debugging.
+        :param log_responses:
+            If ``True``, the raw response from each query is printed out.
+            Useful for debugging.
+        :param connection_kwargs:
+            These are passed directly to the database adapter. We recommend
+            setting ``timeout`` if you expect your application to process a
+            large number of concurrent writes, to prevent queries timing out.
+            See Python's `sqlite3 docs `_
+            for more info.
+
+        """  # noqa: E501
+        default_connection_kwargs = {
+            "database": path,
+            "detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
+            "isolation_level": None,
+        }
+
+        self.log_queries = log_queries
+        self.log_responses = log_responses
+        self.connection_kwargs = {
+            **default_connection_kwargs,
+            **connection_kwargs,
+        }
+
+        self.current_transaction = contextvars.ContextVar(
+            f"sqlite_current_transaction_{path}", default=None
         )
-        self.connection_kwargs = connection_kwargs
-        self.transaction_connection = contextvars.ContextVar(
-            f"sqlite_transaction_connection_{path}", default=None
+        super().__init__(
+            engine_type="sqlite",
+            min_version_number=3.25,
+            log_queries=log_queries,
+            log_responses=log_responses,
         )
-        super().__init__()
-
     @property
     def path(self):
         return self.connection_kwargs["database"]
@@ -368,9 +642,9 @@ def path(self, value: str):
         self.connection_kwargs["database"] = value

     async def get_version(self) -> float:
-        """
-        Warn if the version of SQLite isn't supported.
-        """
+        return self.get_version_sync()
+
+    def get_version_sync(self) -> float:
         major, minor, _ = sqlite3.sqlite_version_info
         return float(f"{major}.{minor}")

@@ -381,35 +655,35 @@ async def prep_database(self):

     def remove_db_file(self):
         """
-        Use with caution - removes the sqlite file. Useful for testing
+        Use with caution - removes the SQLite file. Useful for testing
         purposes.
         """
         if os.path.exists(self.path):
             os.unlink(self.path)

-    def create_db(self, migrate=False):
+    def create_db_file(self):
         """
-        Create the database file, with the option to run migrations. Useful
-        for testing purposes.
+        Create the database file. Useful for testing purposes.
         """
-        if not os.path.exists(self.path):
-            with open(self.path, "w"):
-                pass
-        else:
+        if os.path.exists(self.path):
             raise Exception(f"Database at {self.path} already exists")

-        if migrate:
-            # Commented out for now, as migrations for SQLite aren't as
-            # well supported as Postgres.
- # from piccolo.commands.migration.forwards import ( - # ForwardsMigrationManager, - # ) - - # ForwardsMigrationManager().run() + with open(self.path, "w"): pass ########################################################################### - async def batch(self, query: Query, batch_size: int = 100) -> AsyncBatch: + async def batch( + self, query: Query, batch_size: int = 100, node: Optional[str] = None + ) -> AsyncBatch: + """ + :param query: + The database query to run. + :param batch_size: + How many rows to fetch on each iteration. + :param node: + This is ignored currently, as SQLite runs off a single node. The + value is here so the API is consistent with Postgres. + """ connection = await self.get_connection() return AsyncBatch( connection=connection, query=query, batch_size=batch_size @@ -425,27 +699,28 @@ async def get_connection(self) -> Connection: ########################################################################### - async def _get_inserted_pk(self, cursor, table: t.Type[Table]) -> t.Any: + async def _get_inserted_pk(self, cursor, table: type[Table]) -> Any: """ If the `pk` column is a non-integer then `ROWID` and `pk` will return different types. Need to query by `lastrowid` to get `pk`s in SQLite prior to 3.35.0. """ - # TODO: Add RETURNING clause for sqlite > 3.35.0 await cursor.execute( - f"SELECT {table._meta.primary_key._meta.name} FROM " + f"SELECT {table._meta.primary_key._meta.db_column_name} FROM " f"{table._meta.tablename} WHERE ROWID = {cursor.lastrowid}" ) response = await cursor.fetchone() - return response[table._meta.primary_key._meta.name] + return response[table._meta.primary_key._meta.db_column_name] async def _run_in_new_connection( self, query: str, - args: t.List[t.Any] = [], + args: Optional[list[Any]] = None, query_type: str = "generic", - table: t.Optional[t.Type[Table]] = None, + table: Optional[type[Table]] = None, ): + if args is None: + args = [] async with aiosqlite.connect(**self.connection_kwargs) as connection: await connection.execute("PRAGMA foreign_keys = 1") @@ -453,10 +728,12 @@ async def _run_in_new_connection( async with connection.execute(query, args) as cursor: await connection.commit() - if query_type == "insert": + if query_type == "insert" and self.get_version_sync() < 3.35: + # We can't use the RETURNING clause on older versions + # of SQLite. assert table is not None pk = await self._get_inserted_pk(cursor, table) - return [{table._meta.primary_key._meta.name: pk}] + return [{table._meta.primary_key._meta.db_column_name: pk}] else: return await cursor.fetchall() @@ -464,23 +741,27 @@ async def _run_in_existing_connection( self, connection, query: str, - args: t.List[t.Any] = [], + args: Optional[list[Any]] = None, query_type: str = "generic", - table: t.Optional[t.Type[Table]] = None, + table: Optional[type[Table]] = None, ): """ This is used when a transaction is currently active. """ + if args is None: + args = [] await connection.execute("PRAGMA foreign_keys = 1") connection.row_factory = dict_factory async with connection.execute(query, args) as cursor: response = await cursor.fetchall() - if query_type == "insert": + if query_type == "insert" and self.get_version_sync() < 3.35: + # We can't use the RETURNING clause on older versions + # of SQLite. 
assert table is not None pk = await self._get_inserted_pk(cursor, table) - return [{table._meta.primary_key._meta.name: pk}] + return [{table._meta.primary_key._meta.db_column_name: pk}] else: return response @@ -491,30 +772,80 @@ async def run_querystring( Connection pools aren't currently supported - the argument is there for consistency with other engines. """ + query_id = self.get_query_id() + + if self.log_queries: + self.print_query(query_id=query_id, query=querystring.__str__()) + query, query_args = querystring.compile_string( engine_type=self.engine_type ) # If running inside a transaction: - connection = self.transaction_connection.get() - if connection: - return await self._run_in_existing_connection( - connection=connection, + current_transaction = self.current_transaction.get() + if current_transaction: + response = await self._run_in_existing_connection( + connection=current_transaction.connection, + query=query, + args=query_args, + query_type=querystring.query_type, + table=querystring.table, + ) + else: + response = await self._run_in_new_connection( query=query, args=query_args, query_type=querystring.query_type, table=querystring.table, ) - return await self._run_in_new_connection( - query=query, - args=query_args, - query_type=querystring.query_type, - table=querystring.table, - ) + if self.log_responses: + self.print_response(query_id=query_id, response=response) + + return response + + async def run_ddl(self, ddl: str, in_pool: bool = False): + """ + Connection pools aren't currently supported - the argument is there + for consistency with other engines. + """ + query_id = self.get_query_id() + + if self.log_queries: + self.print_query(query_id=query_id, query=ddl) + + # If running inside a transaction: + current_transaction = self.current_transaction.get() + if current_transaction: + response = await self._run_in_existing_connection( + connection=current_transaction.connection, + query=ddl, + ) + else: + response = await self._run_in_new_connection( + query=ddl, + ) + + if self.log_responses: + self.print_response(query_id=query_id, response=response) + + return response - def atomic(self) -> Atomic: - return Atomic(engine=self) + def atomic( + self, transaction_type: TransactionType = TransactionType.deferred + ) -> Atomic: + return Atomic(engine=self, transaction_type=transaction_type) - def transaction(self) -> Transaction: - return Transaction(engine=self) + def transaction( + self, + transaction_type: TransactionType = TransactionType.deferred, + allow_nested: bool = True, + ) -> SQLiteTransaction: + """ + Create a new database transaction. See :class:`Transaction`. 
+ """ + return SQLiteTransaction( + engine=self, + transaction_type=transaction_type, + allow_nested=allow_nested, + ) diff --git a/piccolo/main.py b/piccolo/main.py index 01596f8c7..6b6a5dc27 100644 --- a/piccolo/main.py +++ b/piccolo/main.py @@ -1,7 +1,7 @@ import os import sys -from targ import CLI # type: ignore +from targ import CLI try: import uvloop # type: ignore @@ -12,17 +12,20 @@ from piccolo.apps.app.piccolo_app import APP_CONFIG as app_config from piccolo.apps.asgi.piccolo_app import APP_CONFIG as asgi_config +from piccolo.apps.fixtures.piccolo_app import APP_CONFIG as fixtures_config from piccolo.apps.meta.piccolo_app import APP_CONFIG as meta_config from piccolo.apps.migrations.commands.check import CheckMigrationManager from piccolo.apps.migrations.piccolo_app import APP_CONFIG as migrations_config from piccolo.apps.playground.piccolo_app import APP_CONFIG as playground_config from piccolo.apps.project.piccolo_app import APP_CONFIG as project_config +from piccolo.apps.schema.piccolo_app import APP_CONFIG as schema_config from piccolo.apps.shell.piccolo_app import APP_CONFIG as shell_config from piccolo.apps.sql_shell.piccolo_app import APP_CONFIG as sql_shell_config +from piccolo.apps.tester.piccolo_app import APP_CONFIG as tester_config from piccolo.apps.user.piccolo_app import APP_CONFIG as user_config from piccolo.conf.apps import AppRegistry, Finder from piccolo.utils.sync import run_sync -from piccolo.utils.warnings import Level, colored_string +from piccolo.utils.warnings import Level, colored_warning DIAGNOSE_FLAG = "--diagnose" @@ -31,7 +34,7 @@ def get_diagnose_flag() -> bool: return DIAGNOSE_FLAG in sys.argv -def main(): +def main() -> None: """ The entrypoint to the Piccolo CLI. """ @@ -58,17 +61,21 @@ def main(): for _app_config in [ app_config, asgi_config, + fixtures_config, meta_config, migrations_config, playground_config, project_config, + schema_config, shell_config, sql_shell_config, + tester_config, user_config, ]: - for command in _app_config.commands: + for command in _app_config.get_commands(): cli.register( command.callable, + command_name=command.command_name, group_name=_app_config.app_name, aliases=command.aliases, ) @@ -86,22 +93,27 @@ def main(): ) else: for app_name, _app_config in APP_REGISTRY.app_configs.items(): - for command in _app_config.commands: + for command in _app_config.get_commands(): if cli.command_exists( - group_name=app_name, command_name=command.callable.__name__ + group_name=app_name, + command_name=command.callable.__name__, ): # Skipping - already registered. continue + cli.register( command.callable, group_name=app_name, aliases=command.aliases, ) - if "migrations" not in sys.argv: + if not {"playground", "migrations", "asgi"}.intersection( + set(sys.argv) + ): # Show a warning if any migrations haven't been run. # Don't run it if it looks like the user is running a migration - # command, as this information is redundant. + # command, or using the playground, as this information is + # redundant. 
try: havent_ran_count = run_sync( @@ -113,18 +125,17 @@ def main(): if havent_ran_count == 1 else f"{havent_ran_count} migrations haven't" ) - print( - colored_string( - message=( - "=> {} been run - the app " - "might not behave as expected.\n" - "To check which use:\n" - " piccolo migrations check\n" - "To run all migrations:\n" - " piccolo migrations forwards all\n" - ).format(message), - level=Level.high, - ) + + colored_warning( + message=( + "=> {} been run - the app " + "might not behave as expected.\n" + "To check which use:\n" + " piccolo migrations check\n" + "To run all migrations:\n" + " piccolo migrations forwards all\n" + ).format(message), + level=Level.high, ) except Exception: pass diff --git a/piccolo/query/__init__.py b/piccolo/query/__init__.py index 410fb0128..2fcc2df7e 100644 --- a/piccolo/query/__init__.py +++ b/piccolo/query/__init__.py @@ -1,7 +1,9 @@ +from piccolo.columns.combination import WhereRaw + from .base import Query +from .functions.aggregate import Avg, Max, Min, Sum from .methods import ( Alter, - Avg, Count, Create, CreateIndex, @@ -9,12 +11,35 @@ DropIndex, Exists, Insert, - Max, - Min, Objects, Raw, Select, - Sum, TableExists, Update, ) +from .methods.select import SelectRaw +from .mixins import OrderByRaw + +__all__ = [ + "Alter", + "Avg", + "Count", + "Create", + "CreateIndex", + "Delete", + "DropIndex", + "Exists", + "Insert", + "Max", + "Min", + "Objects", + "OrderByRaw", + "Query", + "Raw", + "Select", + "SelectRaw", + "Sum", + "TableExists", + "Update", + "WhereRaw", +] diff --git a/piccolo/query/base.py b/piccolo/query/base.py index 5ec36f2de..d45d885dc 100644 --- a/piccolo/query/base.py +++ b/piccolo/query/base.py @@ -1,16 +1,19 @@ from __future__ import annotations -import itertools -import typing as t +from collections.abc import Generator, Sequence from time import time +from typing import TYPE_CHECKING, Any, Generic, Optional, Union, cast from piccolo.columns.column_types import JSON, JSONB +from piccolo.custom_types import QueryResponseType, TableInstance from piccolo.query.mixins import ColumnsDelegate +from piccolo.query.operators.json import JSONQueryString from piccolo.querystring import QueryString -from piccolo.utils.encoding import dump_json, load_json +from piccolo.utils.encoding import load_json +from piccolo.utils.objects import make_nested_object from piccolo.utils.sync import run_sync -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.query.mixins import OutputDelegate from piccolo.table import Table # noqa @@ -24,14 +27,13 @@ def __exit__(self, exception_type, exception, traceback): print(f"Duration: {self.end - self.start}s") -class Query: - +class Query(Generic[TableInstance, QueryResponseType]): __slots__ = ("table", "_frozen_querystrings") def __init__( self, - table: t.Type[Table], - frozen_querystrings: t.Optional[t.Sequence[QueryString]] = None, + table: type[TableInstance], + frozen_querystrings: Optional[Sequence[QueryString]] = None, ): self.table = table self._frozen_querystrings = frozen_querystrings @@ -44,47 +46,46 @@ def engine_type(self) -> str: else: raise ValueError("Engine isn't defined.") - async def _process_results(self, results): # noqa: C901 - if results: - keys = results[0].keys() - keys = [i.replace("$", ".") for i in keys] - raw = [dict(zip(keys, i.values())) for i in results] - else: - raw = [] + async def _process_results(self, results) -> QueryResponseType: + raw = ( + self.table._meta.db.transform_response_to_dicts(results) + if results + else [] + ) - if 
hasattr(self, "run_callback"): - self.run_callback(raw) + if hasattr(self, "_raw_response_callback"): + self._raw_response_callback(raw) - output: t.Optional[OutputDelegate] = getattr( + output: Optional[OutputDelegate] = getattr( self, "output_delegate", None ) ####################################################################### if output and output._output.load_json: - columns_delegate: t.Optional[ColumnsDelegate] = getattr( + columns_delegate: Optional[ColumnsDelegate] = getattr( self, "columns_delegate", None ) + json_column_names: list[str] = [] + if columns_delegate is not None: - json_columns = [ - i - for i in columns_delegate.selected_columns - if isinstance(i, (JSON, JSONB)) - ] + json_columns: list[Union[JSON, JSONB]] = [] + + for column in columns_delegate.selected_columns: + if isinstance(column, (JSON, JSONB)): + json_columns.append(column) + elif isinstance(column, JSONQueryString): + if alias := column._alias: + json_column_names.append(alias) else: json_columns = self.table._meta.json_columns - json_column_names = [] for column in json_columns: - if column.alias is not None: - json_column_names.append(column.alias) + if column._alias is not None: + json_column_names.append(column._alias) elif len(column._meta.call_chain) > 0: - json_column_names.append( - column.get_select_string( - engine_type=column._meta.engine_type - ) - ) + json_column_names.append(column._meta.get_default_alias()) else: json_column_names.append(column._meta.name) @@ -106,34 +107,21 @@ async def _process_results(self, results): # noqa: C901 if output: if output._output.as_objects: - # When using .first() we get a single row, not a list - # of rows. - if type(raw) is list: - raw = [ - self.table(**columns, exists_in_db=True) - for columns in raw - ] - elif raw is None: - pass + if output._output.nested: + return cast( + QueryResponseType, + [make_nested_object(row, self.table) for row in raw], + ) else: - raw = self.table(**raw, exists_in_db=True) - elif type(raw) is list: - if output._output.as_list: - if len(raw) == 0: - return [] - else: - if len(raw[0].keys()) != 1: - raise ValueError( - "Each row returned more than one value" - ) - else: - raw = list( - itertools.chain(*[j.values() for j in raw]) - ) - if output._output.as_json: - raw = dump_json(raw) - - return raw + return cast( + QueryResponseType, + [ + self.table(**columns, _exists_in_db=True) + for columns in raw + ], + ) + + return cast(QueryResponseType, raw) def _validate(self): """ @@ -143,51 +131,94 @@ def _validate(self): """ pass - def __await__(self): + def __await__(self) -> Generator[None, None, QueryResponseType]: """ If the user doesn't explicity call .run(), proxy to it as a convenience. """ return self.run().__await__() - async def run(self, in_pool=True): + async def _run( + self, node: Optional[str] = None, in_pool: bool = True + ) -> QueryResponseType: + """ + Run the query on the database. + + :param node: + If specified, run this query against another database node. Only + available in Postgres. See :class:`PostgresEngine `. + :param in_pool: + Whether to run this in a connection pool if one is available. This + is mostly just for debugging - use a connection pool where + possible. 
+ + """ # noqa: E501 self._validate() engine = self.table._meta.db + if not engine: raise ValueError( f"Table {self.table._meta.tablename} has no db defined in " "_meta" ) - if len(self.querystrings) == 1: + if node is not None: + from piccolo.engine.postgres import PostgresEngine + + if isinstance(engine, PostgresEngine): + engine = engine.extra_nodes[node] + + querystrings = self.querystrings + + if len(querystrings) == 1: results = await engine.run_querystring( - self.querystrings[0], in_pool=in_pool + querystrings[0], in_pool=in_pool ) return await self._process_results(results) else: responses = [] - # TODO - run in a transaction - for querystring in self.querystrings: + for querystring in querystrings: results = await engine.run_querystring( querystring, in_pool=in_pool ) - responses.append(await self._process_results(results)) - return responses + processed_results = await self._process_results(results) - def run_sync(self, timed=False, *args, **kwargs): + responses.append(processed_results) + return cast(QueryResponseType, responses) + + async def run( + self, node: Optional[str] = None, in_pool: bool = True + ) -> QueryResponseType: + return await self._run(node=node, in_pool=in_pool) + + def run_sync( + self, + node: Optional[str] = None, + timed: bool = False, + in_pool: bool = False, + ) -> QueryResponseType: """ A convenience method for running the coroutine synchronously. + + :param timed: + If ``True``, the time taken to run the query is printed out. Useful + for debugging. + :param in_pool: + Whether to run this in a connection pool if one is available. Set + to ``False`` by default, because if an app uses ``run`` and + ``run_sync`` in the same app, it can cause errors. See + `issue 505 `_. + """ - coroutine = self.run(*args, **kwargs, in_pool=False) + coroutine = self.run(node=node, in_pool=in_pool) - if timed: - with Timer(): - return run_sync(coroutine) - else: + if not timed: + return run_sync(coroutine) + with Timer(): return run_sync(coroutine) - async def response_handler(self, response): + async def response_handler(self, response: list) -> Any: """ Subclasses can override this to modify the raw response returned by the database driver. @@ -197,19 +228,23 @@ async def response_handler(self, response): ########################################################################### @property - def sqlite_querystrings(self) -> t.Sequence[QueryString]: + def sqlite_querystrings(self) -> Sequence[QueryString]: raise NotImplementedError @property - def postgres_querystrings(self) -> t.Sequence[QueryString]: + def postgres_querystrings(self) -> Sequence[QueryString]: raise NotImplementedError @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def cockroach_querystrings(self) -> Sequence[QueryString]: raise NotImplementedError @property - def querystrings(self) -> t.Sequence[QueryString]: + def default_querystrings(self) -> Sequence[QueryString]: + raise NotImplementedError + + @property + def querystrings(self) -> Sequence[QueryString]: """ Calls the correct underlying method, depending on the current engine. """ @@ -227,6 +262,11 @@ def querystrings(self) -> t.Sequence[QueryString]: return self.sqlite_querystrings except NotImplementedError: return self.default_querystrings + elif engine_type == "cockroach": + try: + return self.cockroach_querystrings + except NotImplementedError: + return self.default_querystrings else: raise Exception( f"No querystring found for the {engine_type} engine." 
@@ -255,7 +295,7 @@ def freeze(self) -> FrozenQuery: # In the corresponding view/endpoint of whichever web framework # you're using: async def top_bands(self, request): - return await TOP_BANDS.run() + return await TOP_BANDS It means that Piccolo doesn't have to work as hard each time the query is run to generate the corresponding SQL - some of it is cached. If the @@ -277,7 +317,7 @@ async def top_bands(self, request): # Copy the query, so we don't store any references to the original. query = self.__class__( - table=self.table, frozen_querystrings=self.querystrings + table=self.table, frozen_querystrings=querystrings ) if hasattr(self, "limit_delegate"): @@ -296,6 +336,9 @@ def __str__(self) -> str: return "; ".join([i.__str__() for i in self.querystrings]) +############################################################################### + + class FrozenQuery: def __init__(self, query: Query): self.query = query @@ -317,3 +360,100 @@ def __getattr__(self, name: str): def __str__(self) -> str: return self.query.__str__() + + +############################################################################### + + +class DDL: + __slots__ = ("table",) + + def __init__(self, table: type[Table], **kwargs): + self.table = table + + @property + def engine_type(self) -> str: + engine = self.table._meta.db + if engine: + return engine.engine_type + else: + raise ValueError("Engine isn't defined.") + + @property + def sqlite_ddl(self) -> Sequence[str]: + raise NotImplementedError + + @property + def postgres_ddl(self) -> Sequence[str]: + raise NotImplementedError + + @property + def cockroach_ddl(self) -> Sequence[str]: + raise NotImplementedError + + @property + def default_ddl(self) -> Sequence[str]: + raise NotImplementedError + + @property + def ddl(self) -> Sequence[str]: + """ + Calls the correct underlying method, depending on the current engine. + """ + engine_type = self.engine_type + if engine_type == "postgres": + try: + return self.postgres_ddl + except NotImplementedError: + return self.default_ddl + elif engine_type == "sqlite": + try: + return self.sqlite_ddl + except NotImplementedError: + return self.default_ddl + elif engine_type == "cockroach": + try: + return self.cockroach_ddl + except NotImplementedError: + return self.default_ddl + else: + raise Exception( + f"No querystring found for the {engine_type} engine." + ) + + def __await__(self): + """ + If the user doesn't explicity call .run(), proxy to it as a + convenience. + """ + return self.run().__await__() + + async def run(self, in_pool=True): + engine = self.table._meta.db + if not engine: + raise ValueError( + f"Table {self.table._meta.tablename} has no db defined in " + "_meta" + ) + + if len(self.ddl) == 1: + return await engine.run_ddl(self.ddl[0], in_pool=in_pool) + responses = [] + for ddl in self.ddl: + response = await engine.run_ddl(ddl, in_pool=in_pool) + responses.append(response) + return responses + + def run_sync(self, timed=False, *args, **kwargs): + """ + A convenience method for running the coroutine synchronously. 
+ """ + coroutine = self.run(*args, **kwargs, in_pool=False) + + if not timed: + return run_sync(coroutine) + with Timer(): + return run_sync(coroutine) + + def __str__(self) -> str: + return self.ddl.__str__() diff --git a/piccolo/query/constraints.py b/piccolo/query/constraints.py new file mode 100644 index 000000000..a5859c100 --- /dev/null +++ b/piccolo/query/constraints.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import Optional + +from piccolo.columns import ForeignKey +from piccolo.columns.base import OnDelete, OnUpdate + + +async def get_fk_constraint_name(column: ForeignKey) -> Optional[str]: + """ + Checks what the foreign key constraint is called in the database. + """ + + table = column._meta.table + + if table._meta.db.engine_type == "sqlite": + # TODO - add the query for SQLite + raise ValueError("SQLite isn't currently supported.") + + schema = table._meta.schema or "public" + table_name = table._meta.tablename + column_name = column._meta.db_column_name + + constraints = await table.raw( + """ + SELECT + kcu.constraint_name AS fk_constraint_name + FROM + information_schema.referential_constraints AS rc + INNER JOIN + information_schema.key_column_usage AS kcu + ON kcu.constraint_catalog = rc.constraint_catalog + AND kcu.constraint_schema = rc.constraint_schema + AND kcu.constraint_name = rc.constraint_name + WHERE + kcu.table_schema = {} AND + kcu.table_name = {} AND + kcu.column_name = {} + """, + schema, + table_name, + column_name, + ) + + # if we change the column type from a non-FK column to + # an FK column, the previous column type has no FK constraints + # and we skip this to allow the migration to continue + return constraints[0]["fk_constraint_name"] if constraints else None + + +@dataclass +class ConstraintRules: + on_delete: OnDelete + on_update: OnUpdate + + +async def get_fk_constraint_rules(column: ForeignKey) -> ConstraintRules: + """ + Checks the constraint rules for this foreign key in the database. 
+ """ + table = column._meta.table + + if table._meta.db.engine_type == "sqlite": + # TODO - add the query for SQLite + raise ValueError("SQLite isn't currently supported.") + + schema = table._meta.schema or "public" + table_name = table._meta.tablename + column_name = column._meta.db_column_name + + constraints = await table.raw( + """ + SELECT + kcu.constraint_name, + kcu.table_name, + kcu.column_name, + rc.update_rule, + rc.delete_rule + FROM + information_schema.key_column_usage AS kcu + INNER JOIN + information_schema.referential_constraints AS rc + ON kcu.constraint_name = rc.constraint_name + WHERE + kcu.table_schema = {} AND + kcu.table_name = {} AND + kcu.column_name = {} + """, + schema, + table_name, + column_name, + ) + + return ConstraintRules( + on_delete=OnDelete(constraints[0]["delete_rule"]), + on_update=OnUpdate(constraints[0]["update_rule"]), + ) diff --git a/piccolo/query/functions/__init__.py b/piccolo/query/functions/__init__.py new file mode 100644 index 000000000..9b233cca8 --- /dev/null +++ b/piccolo/query/functions/__init__.py @@ -0,0 +1,45 @@ +from .aggregate import Avg, Count, Max, Min, Sum +from .array import ( + ArrayAppend, + ArrayCat, + ArrayPrepend, + ArrayRemove, + ArrayReplace, +) +from .datetime import Day, Extract, Hour, Month, Second, Strftime, Year +from .math import Abs, Ceil, Floor, Round +from .string import Concat, Length, Lower, Ltrim, Reverse, Rtrim, Upper +from .type_conversion import Cast + +__all__ = ( + "Abs", + "Avg", + "Cast", + "Ceil", + "Concat", + "Count", + "Day", + "Extract", + "Extract", + "Floor", + "Hour", + "Length", + "Lower", + "Ltrim", + "Max", + "Min", + "Month", + "Reverse", + "Round", + "Rtrim", + "Second", + "Strftime", + "Sum", + "Upper", + "Year", + "ArrayAppend", + "ArrayCat", + "ArrayPrepend", + "ArrayRemove", + "ArrayReplace", +) diff --git a/piccolo/query/functions/aggregate.py b/piccolo/query/functions/aggregate.py new file mode 100644 index 000000000..499d56007 --- /dev/null +++ b/piccolo/query/functions/aggregate.py @@ -0,0 +1,180 @@ +from collections.abc import Sequence +from typing import Optional + +from piccolo.columns.base import Column +from piccolo.querystring import QueryString + +from .base import Function + + +class Avg(Function): + """ + ``AVG()`` SQL function. Column type must be numeric to run the query. + + .. code-block:: python + + await Band.select(Avg(Band.popularity)) + + # We can use an alias. These two are equivalent: + + await Band.select( + Avg(Band.popularity, alias="popularity_avg") + ) + + await Band.select( + Avg(Band.popularity).as_alias("popularity_avg") + ) + + """ + + function_name = "AVG" + + +class Count(QueryString): + """ + Used in ``Select`` queries, usually in conjunction with the ``group_by`` + clause:: + + >>> await Band.select( + ... Band.manager.name.as_alias('manager_name'), + ... Count(alias='band_count') + ... ).group_by(Band.manager) + [{'manager_name': 'Guido', 'count': 1}, ...] + + It can also be used without the ``group_by`` clause (though you may prefer + to the :meth:`Table.count ` method instead, as + it's more convenient):: + + >>> await Band.select(Count()) + [{'count': 3}] + + """ + + def __init__( + self, + column: Optional[Column] = None, + distinct: Optional[Sequence[Column]] = None, + alias: str = "count", + ): + """ + :param column: + If specified, the count is for non-null values in that column. + :param distinct: + If specified, the count is for distinct values in those columns. 
+ :param alias: + The name of the value in the response:: + + # These two are equivalent: + + await Band.select( + Band.name, Count(alias="total") + ).group_by(Band.name) + + await Band.select( + Band.name, + Count().as_alias("total") + ).group_by(Band.name) + + """ + if distinct and column: + raise ValueError("Only specify `column` or `distinct`") + + if distinct: + engine_type = distinct[0]._meta.engine_type + if engine_type == "sqlite": + # SQLite doesn't allow us to specify multiple columns, so + # instead we concatenate the values. + column_names = " || ".join("{}" for _ in distinct) + else: + column_names = ", ".join("{}" for _ in distinct) + + return super().__init__( + f"COUNT(DISTINCT({column_names}))", *distinct, alias=alias + ) + else: + if column: + return super().__init__("COUNT({})", column, alias=alias) + else: + return super().__init__("COUNT(*)", alias=alias) + + +class Min(Function): + """ + ``MIN()`` SQL function. + + .. code-block:: python + + await Band.select(Min(Band.popularity)) + + # We can use an alias. These two are equivalent: + + await Band.select( + Min(Band.popularity, alias="popularity_min") + ) + + await Band.select( + Min(Band.popularity).as_alias("popularity_min") + ) + + """ + + function_name = "MIN" + + +class Max(Function): + """ + ``MAX()`` SQL function. + + .. code-block:: python + + await Band.select( + Max(Band.popularity) + ) + + # We can use an alias. These two are equivalent: + + await Band.select( + Max(Band.popularity, alias="popularity_max") + ) + + await Band.select( + Max(Band.popularity).as_alias("popularity_max") + ) + + """ + + function_name = "MAX" + + +class Sum(Function): + """ + ``SUM()`` SQL function. Column type must be numeric to run the query. + + .. code-block:: python + + await Band.select( + Sum(Band.popularity) + ) + + # We can use an alias. These two are equivalent: + + await Band.select( + Sum(Band.popularity, alias="popularity_sum") + ) + + await Band.select( + Sum(Band.popularity).as_alias("popularity_sum") + ) + + """ + + function_name = "SUM" + + +__all__ = ( + "Avg", + "Count", + "Min", + "Max", + "Sum", +) diff --git a/piccolo/query/functions/array.py b/piccolo/query/functions/array.py new file mode 100644 index 000000000..13e020929 --- /dev/null +++ b/piccolo/query/functions/array.py @@ -0,0 +1,151 @@ +from typing import Union + +from typing_extensions import TypeAlias + +from piccolo.columns.base import Column +from piccolo.querystring import QueryString + +ArrayType: TypeAlias = Union[Column, QueryString, list[object]] +ArrayItemType: TypeAlias = Union[Column, QueryString, object] + + +class ArrayQueryString(QueryString): + def __add__(self, array: ArrayType): + """ + QueryString will use the ``+`` operator by default for addition, but + for arrays we want to concatenate them instead. + """ + return ArrayCat(array_1=self, array_2=array) + + def __radd__(self, array: ArrayType): + return ArrayCat(array_1=array, array_2=self) + + +class ArrayCat(ArrayQueryString): + def __init__( + self, + array_1: ArrayType, + array_2: ArrayType, + ): + """ + Concatenate two arrays. + + :param array_1: + These values will be at the start of the new array. + :param array_2: + These values will be at the end of the new array. + + """ + for value in (array_1, array_2): + if isinstance(value, Column): + engine_type = value._meta.engine_type + if engine_type not in ("postgres", "cockroach"): + raise ValueError( + "Only Postgres and Cockroach support array " + "concatenation." 
+                    )
+
+        super().__init__("array_cat({}, {})", array_1, array_2)
+
+
+class ArrayAppend(ArrayQueryString):
+    def __init__(self, array: ArrayType, value: ArrayItemType):
+        """
+        Append an element to the end of an array.
+
+        :param array:
+            Identifies the array column.
+        :param value:
+            The value to append.
+
+        """
+        if isinstance(array, Column):
+            engine_type = array._meta.engine_type
+            if engine_type not in ("postgres", "cockroach"):
+                raise ValueError(
+                    "Only Postgres and Cockroach support array appending."
+                )
+
+        super().__init__("array_append({}, {})", array, value)
+
+
+class ArrayPrepend(ArrayQueryString):
+    def __init__(self, array: ArrayType, value: ArrayItemType):
+        """
+        Prepend an element to the beginning of an array.
+
+        :param array:
+            Identifies the array column.
+        :param value:
+            The value to prepend.
+
+        """
+        if isinstance(array, Column):
+            engine_type = array._meta.engine_type
+            if engine_type not in ("postgres", "cockroach"):
+                raise ValueError(
+                    "Only Postgres and Cockroach support array prepending."
+                )
+
+        super().__init__("array_prepend({}, {})", value, array)
+
+
+class ArrayReplace(ArrayQueryString):
+    def __init__(
+        self,
+        array: ArrayType,
+        old_value: ArrayItemType,
+        new_value: ArrayItemType,
+    ):
+        """
+        Replace each array element equal to the given value with a new value.
+
+        :param array:
+            Identifies the array column.
+        :param old_value:
+            The old value to be replaced.
+        :param new_value:
+            The new value we are replacing with.
+
+        """
+        if isinstance(array, Column):
+            engine_type = array._meta.engine_type
+            if engine_type not in ("postgres", "cockroach"):
+                raise ValueError(
+                    "Only Postgres and Cockroach support array substitution."
+                )
+
+        super().__init__(
+            "array_replace({}, {}, {})", array, old_value, new_value
+        )
+
+
+class ArrayRemove(ArrayQueryString):
+    def __init__(self, array: ArrayType, value: ArrayItemType):
+        """
+        Remove all elements equal to the given value from the array (the
+        array must be one-dimensional).
+
+        :param array:
+            Identifies the array column.
+        :param value:
+            The value to remove.
+
+        """
+        if isinstance(array, Column):
+            engine_type = array._meta.engine_type
+            if engine_type not in ("postgres", "cockroach"):
+                raise ValueError(
+                    "Only Postgres and Cockroach support array removing."
+                )
+
+        super().__init__("array_remove({}, {})", array, value)
+
+
+__all__ = (
+    "ArrayCat",
+    "ArrayAppend",
+    "ArrayPrepend",
+    "ArrayReplace",
+    "ArrayRemove",
+)
diff --git a/piccolo/query/functions/base.py b/piccolo/query/functions/base.py
new file mode 100644
index 000000000..807de8365
--- /dev/null
+++ b/piccolo/query/functions/base.py
@@ -0,0 +1,21 @@
+from typing import Optional, Union
+
+from piccolo.columns.base import Column
+from piccolo.querystring import QueryString
+
+
+class Function(QueryString):
+    function_name: str
+
+    def __init__(
+        self,
+        identifier: Union[Column, QueryString, str],
+        alias: Optional[str] = None,
+    ):
+        alias = alias or self.__class__.__name__.lower()
+
+        super().__init__(
+            f"{self.function_name}({{}})",
+            identifier,
+            alias=alias,
+        )
diff --git a/piccolo/query/functions/datetime.py b/piccolo/query/functions/datetime.py
new file mode 100644
index 000000000..146fc5c42
--- /dev/null
+++ b/piccolo/query/functions/datetime.py
@@ -0,0 +1,260 @@
+from typing import Literal, Optional, Union, get_args
+
+from piccolo.columns.base import Column
+from piccolo.columns.column_types import (
+    Date,
+    Integer,
+    Time,
+    Timestamp,
+    Timestamptz,
+)
+from piccolo.querystring import QueryString
+
+from .type_conversion import Cast
+
+###############################################################################
+# Postgres / Cockroach
+
+ExtractComponent = Literal[
+    "century",
+    "day",
+    "decade",
+    "dow",
+    "doy",
+    "epoch",
+    "hour",
+    "isodow",
+    "isoyear",
+    "julian",
+    "microseconds",
+    "millennium",
+    "milliseconds",
+    "minute",
+    "month",
+    "quarter",
+    "second",
+    "timezone",
+    "timezone_hour",
+    "timezone_minute",
+    "week",
+    "year",
+]
+
+
+class Extract(QueryString):
+    def __init__(
+        self,
+        identifier: Union[Date, Time, Timestamp, Timestamptz, QueryString],
+        datetime_component: ExtractComponent,
+        alias: Optional[str] = None,
+    ):
+        """
+        .. note:: This is for Postgres / Cockroach only.
+
+        Extract a date or time component from a ``Date`` / ``Time`` /
+        ``Timestamp`` / ``Timestamptz`` column. For example, getting the
+        month from a timestamp:
+
+        .. code-block:: python
+
+            >>> from piccolo.query.functions import Extract
+            >>> await Concert.select(
+            ...     Extract(Concert.starts, "month", alias="start_month")
+            ... )
+            [{"start_month": 12}]
+
+        :param identifier:
+            Identifies the column.
+        :param datetime_component:
+            The date or time component to extract from the column.
+
+        """
+        if datetime_component.lower() not in get_args(ExtractComponent):
+            raise ValueError("The date time component isn't recognised.")
+
+        super().__init__(
+            f"EXTRACT({datetime_component} FROM {{}})",
+            identifier,
+            alias=alias,
+        )
+
+
+###############################################################################
+# SQLite
+
+
+class Strftime(QueryString):
+    def __init__(
+        self,
+        identifier: Union[Date, Time, Timestamp, Timestamptz, QueryString],
+        datetime_format: str,
+        alias: Optional[str] = None,
+    ):
+        """
+        .. note:: This is for SQLite only.
+
+        Format a datetime value. For example:
+
+        .. code-block:: python
+
+            >>> from piccolo.query.functions import Strftime
+            >>> await Concert.select(
+            ...     Strftime(Concert.starts, "%Y", alias="start_year")
+            ... )
+            [{"start_year": "2024"}]
+
+        :param identifier:
+            Identifies the column.
+        :param datetime_format:
+            A string describing the output format (see SQLite's
+            `documentation `_
+            for more info).
+ + """ + super().__init__( + f"strftime('{datetime_format}', {{}})", + identifier, + alias=alias, + ) + + +############################################################################### +# Database agnostic + + +def _get_engine_type(identifier: Union[Column, QueryString]) -> str: + if isinstance(identifier, Column): + return identifier._meta.engine_type + elif isinstance(identifier, QueryString) and ( + columns := identifier.columns + ): + return columns[0]._meta.engine_type + else: + raise ValueError("Unable to determine the engine type") + + +def _extract_component( + identifier: Union[Date, Time, Timestamp, Timestamptz, QueryString], + sqlite_format: str, + postgres_format: ExtractComponent, + alias: Optional[str], +) -> QueryString: + engine_type = _get_engine_type(identifier=identifier) + + return Cast( + ( + Strftime( + identifier=identifier, + datetime_format=sqlite_format, + ) + if engine_type == "sqlite" + else Extract( + identifier=identifier, + datetime_component=postgres_format, + ) + ), + Integer(), + alias=alias, + ) + + +def Year( + identifier: Union[Date, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the year as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%Y", + postgres_format="year", + alias=alias, + ) + + +def Month( + identifier: Union[Date, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the month as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%m", + postgres_format="month", + alias=alias, + ) + + +def Day( + identifier: Union[Date, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the day as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%d", + postgres_format="day", + alias=alias, + ) + + +def Hour( + identifier: Union[Time, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the hour as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%H", + postgres_format="hour", + alias=alias, + ) + + +def Minute( + identifier: Union[Time, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the minute as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%M", + postgres_format="minute", + alias=alias, + ) + + +def Second( + identifier: Union[Time, Timestamp, Timestamptz, QueryString], + alias: Optional[str] = None, +) -> QueryString: + """ + Extract the second as an integer. + """ + return _extract_component( + identifier=identifier, + sqlite_format="%S", + postgres_format="second", + alias=alias, + ) + + +__all__ = ( + "Extract", + "Strftime", + "Year", + "Month", + "Day", + "Hour", + "Minute", + "Second", +) diff --git a/piccolo/query/functions/math.py b/piccolo/query/functions/math.py new file mode 100644 index 000000000..e0ebaf70f --- /dev/null +++ b/piccolo/query/functions/math.py @@ -0,0 +1,48 @@ +""" +These functions mirror their counterparts in the Postgresql docs: + +https://www.postgresql.org/docs/current/functions-math.html + +""" + +from .base import Function + + +class Abs(Function): + """ + Absolute value. + """ + + function_name = "ABS" + + +class Ceil(Function): + """ + Nearest integer greater than or equal to argument. 
+ """ + + function_name = "CEIL" + + +class Floor(Function): + """ + Nearest integer less than or equal to argument. + """ + + function_name = "FLOOR" + + +class Round(Function): + """ + Rounds to nearest integer. + """ + + function_name = "ROUND" + + +__all__ = ( + "Abs", + "Ceil", + "Floor", + "Round", +) diff --git a/piccolo/query/functions/string.py b/piccolo/query/functions/string.py new file mode 100644 index 000000000..3aa4a5d45 --- /dev/null +++ b/piccolo/query/functions/string.py @@ -0,0 +1,118 @@ +""" +These functions mirror their counterparts in the Postgresql docs: + +https://www.postgresql.org/docs/current/functions-string.html + +""" + +from typing import Optional, Union + +from piccolo.columns.base import Column +from piccolo.columns.column_types import Text, Varchar +from piccolo.querystring import QueryString + +from .base import Function + + +class Length(Function): + """ + Returns the number of characters in the string. + """ + + function_name = "LENGTH" + + +class Lower(Function): + """ + Converts the string to all lower case, according to the rules of the + database's locale. + """ + + function_name = "LOWER" + + +class Ltrim(Function): + """ + Removes the longest string containing only characters in characters (a + space by default) from the start of string. + """ + + function_name = "LTRIM" + + +class Reverse(Function): + """ + Return reversed string. + + Not supported in SQLite. + + """ + + function_name = "REVERSE" + + +class Rtrim(Function): + """ + Removes the longest string containing only characters in characters (a + space by default) from the end of string. + """ + + function_name = "RTRIM" + + +class Upper(Function): + """ + Converts the string to all upper case, according to the rules of the + database's locale. + """ + + function_name = "UPPER" + + +class Concat(QueryString): + def __init__( + self, + *args: Union[Column, QueryString, str], + alias: Optional[str] = None, + ): + """ + Concatenate multiple values into a single string. + + .. note:: + Null values are ignored, so ``null + '!!!'`` returns ``!!!``, + not ``null``. + + .. warning:: + For SQLite, this is only available in version 3.44.0 and above. + + """ + if len(args) < 2: + raise ValueError("At least two values must be passed in.") + + placeholders = ", ".join("{}" for _ in args) + + processed_args: list[Union[QueryString, Column]] = [] + + for arg in args: + if isinstance(arg, str) or ( + isinstance(arg, Column) + and not isinstance(arg, (Varchar, Text)) + ): + processed_args.append(QueryString("CAST({} AS TEXT)", arg)) + else: + processed_args.append(arg) + + super().__init__( + f"CONCAT({placeholders})", *processed_args, alias=alias + ) + + +__all__ = ( + "Length", + "Lower", + "Ltrim", + "Reverse", + "Rtrim", + "Upper", + "Concat", +) diff --git a/piccolo/query/functions/type_conversion.py b/piccolo/query/functions/type_conversion.py new file mode 100644 index 000000000..1bbb44f72 --- /dev/null +++ b/piccolo/query/functions/type_conversion.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import Optional, Union + +from piccolo.columns.base import Column +from piccolo.custom_types import BasicTypes +from piccolo.querystring import QueryString + + +class Cast(QueryString): + def __init__( + self, + identifier: Union[Column, QueryString, BasicTypes], + as_type: Column, + alias: Optional[str] = None, + ): + """ + Cast a value to a different type. For example:: + + >>> from piccolo.query.functions import Cast + + >>> await Concert.select( + ... 
Cast(Concert.starts, Time(), alias="start_time") + ... ) + [{"start_time": datetime.time(19, 0)}] + + You may also need ``Cast`` to explicitly tell the database which type + you're sending in the query (though this is an edge case). Here is a + contrived example:: + + >>> from piccolo.query.functions.math import Count + + # This fails with asyncpg: + >>> await Band.select(Count([1,2,3])) + + If we explicitly specify the type of the array, then it works:: + + >>> await Band.select( + ... Count( + ... Cast( + ... [1,2,3], + ... Array(Integer()) + ... ), + ... ) + ... ) + + :param identifier: + Identifies what is being converted (e.g. a column, or a raw value). + :param as_type: + The type to be converted to. + + """ + # Convert `as_type` to a string which can be used in the query. + + if not isinstance(as_type, Column): + raise ValueError("The `as_type` value must be a Column instance.") + + # We need to give the column a reference to a table, and hence + # the database engine, as the column type is sometimes dependent + # on which database is being used. + + from piccolo.table import Table, create_table_class + + table: Optional[type[Table]] = None + + if isinstance(identifier, Column): + table = identifier._meta.table + elif isinstance(identifier, QueryString): + table = ( + identifier.columns[0]._meta.table + if identifier.columns + else None + ) + + as_type._meta.table = table or create_table_class("Table") + as_type_string = as_type.column_type + + ####################################################################### + # Preserve the original alias from the column. + + if isinstance(identifier, Column): + alias = ( + alias + or identifier._alias + or identifier._meta.get_default_alias() + ) + + ####################################################################### + + super().__init__( + f"CAST({{}} AS {as_type_string})", + identifier, + alias=alias, + ) + + +__all__ = ("Cast",) diff --git a/piccolo/query/methods/__init__.py b/piccolo/query/methods/__init__.py index aaf1d41b1..f4b9a59f1 100644 --- a/piccolo/query/methods/__init__.py +++ b/piccolo/query/methods/__init__.py @@ -8,6 +8,24 @@ from .insert import Insert from .objects import Objects from .raw import Raw -from .select import Avg, Max, Min, Select, Sum +from .refresh import Refresh +from .select import Select from .table_exists import TableExists from .update import Update + +__all__ = ( + "Alter", + "Count", + "Create", + "CreateIndex", + "Delete", + "DropIndex", + "Exists", + "Insert", + "Objects", + "Raw", + "Refresh", + "Select", + "TableExists", + "Update", +) diff --git a/piccolo/query/methods/alter.py b/piccolo/query/methods/alter.py index 5e9b3b5be..35774acd3 100644 --- a/piccolo/query/methods/alter.py +++ b/piccolo/query/methods/alter.py @@ -1,29 +1,29 @@ from __future__ import annotations import itertools -import typing as t +from collections.abc import Sequence from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from piccolo.columns.base import Column from piccolo.columns.column_types import ForeignKey, Numeric, Varchar -from piccolo.query.base import Query -from piccolo.querystring import QueryString +from piccolo.query.base import DDL from piccolo.utils.warnings import Level, colored_warning -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import OnDelete, OnUpdate from piccolo.table import Table -@dataclass class AlterStatement: - __slots__ = tuple() # type: ignore + __slots__ = () # type: ignore - def 
querystring(self) -> QueryString: + @property + def ddl(self) -> str: raise NotImplementedError() def __str__(self) -> str: - return self.querystring.__str__() + return self.ddl @dataclass @@ -33,22 +33,34 @@ class RenameTable(AlterStatement): new_name: str @property - def querystring(self) -> QueryString: - return QueryString(f"RENAME TO {self.new_name}") + def ddl(self) -> str: + return f"RENAME TO {self.new_name}" + + +@dataclass +class RenameConstraint(AlterStatement): + __slots__ = ("old_name", "new_name") + + old_name: str + new_name: str + + @property + def ddl(self) -> str: + return f"RENAME CONSTRAINT {self.old_name} TO {self.new_name}" @dataclass class AlterColumnStatement(AlterStatement): __slots__ = ("column",) - column: t.Union[Column, str] + column: Union[Column, str] @property def column_name(self) -> str: if isinstance(self.column, str): return self.column elif isinstance(self.column, Column): - return self.column._meta.name + return self.column._meta.db_column_name else: raise ValueError("Unrecognised column type") @@ -60,37 +72,35 @@ class RenameColumn(AlterColumnStatement): new_name: str @property - def querystring(self) -> QueryString: - return QueryString( - f"RENAME COLUMN {self.column_name} TO {self.new_name}" - ) + def ddl(self) -> str: + return f'RENAME COLUMN "{self.column_name}" TO "{self.new_name}"' @dataclass class DropColumn(AlterColumnStatement): @property - def querystring(self) -> QueryString: - return QueryString(f"DROP COLUMN {self.column_name}") + def ddl(self) -> str: + return f'DROP COLUMN "{self.column_name}"' @dataclass class AddColumn(AlterColumnStatement): - __slots__ = ("column", "name") + __slots__ = ("name",) column: Column name: str @property - def querystring(self) -> QueryString: + def ddl(self) -> str: self.column._meta.name = self.name - return QueryString("ADD COLUMN {}", self.column.querystring) + return f"ADD COLUMN {self.column.ddl}" @dataclass class DropDefault(AlterColumnStatement): @property - def querystring(self) -> QueryString: - return QueryString(f"ALTER COLUMN {self.column_name} DROP DEFAULT") + def ddl(self) -> str: + return f'ALTER COLUMN "{self.column_name}" DROP DEFAULT' @dataclass @@ -105,35 +115,33 @@ class SetColumnType(AlterStatement): old_column: Column new_column: Column - using_expression: t.Optional[str] = None + using_expression: Optional[str] = None @property - def querystring(self) -> QueryString: + def ddl(self) -> str: if self.new_column._meta._table is None: self.new_column._meta._table = self.old_column._meta.table - column_name = self.old_column._meta.name + column_name = self.old_column._meta.db_column_name query = ( - f"ALTER COLUMN {column_name} TYPE {self.new_column.column_type}" + f'ALTER COLUMN "{column_name}" TYPE {self.new_column.column_type}' ) if self.using_expression is not None: query += f" USING {self.using_expression}" - return QueryString(query) + return query @dataclass class SetDefault(AlterColumnStatement): - __slots__ = ("column", "value") + __slots__ = ("value",) column: Column - value: t.Any + value: Any @property - def querystring(self) -> QueryString: + def ddl(self) -> str: sql_value = self.column.get_sql_value(self.value) - return QueryString( - f"ALTER COLUMN {self.column_name} SET DEFAULT {sql_value}" - ) + return f'ALTER COLUMN "{self.column_name}" SET DEFAULT {sql_value}' @dataclass @@ -143,19 +151,18 @@ class SetUnique(AlterColumnStatement): boolean: bool @property - def querystring(self) -> QueryString: + def ddl(self) -> str: if self.boolean: - return QueryString(f"ADD UNIQUE 
({self.column_name})") - else: - if isinstance(self.column, str): - raise ValueError( - "Removing a unique constraint requires a Column instance " - "to be passed as the column arg instead of a string." - ) - tablename = self.column._meta.table._meta.tablename - column_name = self.column_name - key = f"{tablename}_{column_name}_key" - return QueryString(f'DROP CONSTRAINT "{key}"') + return f'ADD UNIQUE ("{self.column_name}")' + if isinstance(self.column, str): + raise ValueError( + "Removing a unique constraint requires a Column instance " + "to be passed as the column arg instead of a string." + ) + tablename = self.column._meta.table._meta.tablename + column_name = self.column_name + key = f"{tablename}_{column_name}_key" + return f'DROP CONSTRAINT "{key}"' @dataclass @@ -165,27 +172,22 @@ class SetNull(AlterColumnStatement): boolean: bool @property - def querystring(self) -> QueryString: + def ddl(self) -> str: if self.boolean: - return QueryString( - f"ALTER COLUMN {self.column_name} DROP NOT NULL" - ) + return f'ALTER COLUMN "{self.column_name}" DROP NOT NULL' else: - return QueryString(f"ALTER COLUMN {self.column_name} SET NOT NULL") + return f'ALTER COLUMN "{self.column_name}" SET NOT NULL' @dataclass class SetLength(AlterColumnStatement): - __slots__ = ("length",) length: int @property - def querystring(self) -> QueryString: - return QueryString( - f"ALTER COLUMN {self.column_name} TYPE VARCHAR({self.length})" - ) + def ddl(self) -> str: + return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})' @dataclass @@ -195,10 +197,8 @@ class DropConstraint(AlterStatement): constraint_name: str @property - def querystring(self) -> QueryString: - return QueryString( - f"DROP CONSTRAINT IF EXISTS {self.constraint_name}", - ) + def ddl(self) -> str: + return f"DROP CONSTRAINT IF EXISTS {self.constraint_name}" @dataclass @@ -207,6 +207,7 @@ class AddForeignKeyConstraint(AlterStatement): "constraint_name", "foreign_key_column_name", "referenced_table_name", + "referenced_column_name", "on_delete", "on_update", ) @@ -214,74 +215,81 @@ class AddForeignKeyConstraint(AlterStatement): constraint_name: str foreign_key_column_name: str referenced_table_name: str - on_delete: t.Optional[OnDelete] - on_update: t.Optional[OnUpdate] - referenced_column_name: str = "id" + referenced_column_name: str + on_delete: Optional[OnDelete] + on_update: Optional[OnUpdate] @property - def querystring(self) -> QueryString: + def ddl(self) -> str: query = ( - f"ADD CONSTRAINT {self.constraint_name} FOREIGN KEY " - f"({self.foreign_key_column_name}) REFERENCES " - f"{self.referenced_table_name} ({self.referenced_column_name})" + f'ADD CONSTRAINT "{self.constraint_name}" FOREIGN KEY ' + f'("{self.foreign_key_column_name}") REFERENCES ' + f'"{self.referenced_table_name}" ("{self.referenced_column_name}")' ) if self.on_delete: query += f" ON DELETE {self.on_delete.value}" if self.on_update: query += f" ON UPDATE {self.on_update.value}" - return QueryString(query) + return query @dataclass class SetDigits(AlterColumnStatement): - __slots__ = ("digits", "column_type") - digits: t.Optional[t.Tuple[int, int]] + digits: Optional[tuple[int, int]] column_type: str @property - def querystring(self) -> QueryString: - if self.digits is not None: - precision = self.digits[0] - scale = self.digits[1] - return QueryString( - f"ALTER COLUMN {self.column_name} TYPE " - f"{self.column_type}({precision}, {scale})" - ) - else: - return QueryString( - f"ALTER COLUMN {self.column_name} TYPE {self.column_type}", - ) + def 
ddl(self) -> str: + if self.digits is None: + return f'ALTER COLUMN "{self.column_name}" TYPE {self.column_type}' + + precision = self.digits[0] + scale = self.digits[1] + return ( + f'ALTER COLUMN "{self.column_name}" TYPE ' + f"{self.column_type}({precision}, {scale})" + ) + + +@dataclass +class SetSchema(AlterStatement): + __slots__ = ("schema_name",) + + schema_name: str + + @property + def ddl(self) -> str: + return f'SET SCHEMA "{self.schema_name}"' @dataclass class DropTable: - tablename: str + table: type[Table] cascade: bool if_exists: bool @property - def querystring(self) -> QueryString: + def ddl(self) -> str: query = "DROP TABLE" if self.if_exists: query += " IF EXISTS" - query += f" {self.tablename}" + query += f" {self.table._meta.get_formatted_tablename()}" if self.cascade: query += " CASCADE" - return QueryString(query) - + return query -class Alter(Query): +class Alter(DDL): __slots__ = ( - "_add_foreign_key_constraint", "_add", - "_drop_contraint", + "_add_foreign_key_constraint", + "_drop_constraint", "_drop_default", "_drop_table", "_drop", @@ -292,44 +300,63 @@ class Alter(Query): "_set_digits", "_set_length", "_set_null", + "_set_schema", "_set_unique", + "_rename_constraint", ) - def __init__(self, table: t.Type[Table], **kwargs): + def __init__(self, table: type[Table], **kwargs): super().__init__(table, **kwargs) - self._add_foreign_key_constraint: t.List[AddForeignKeyConstraint] = [] - self._add: t.List[AddColumn] = [] - self._drop_contraint: t.List[DropConstraint] = [] - self._drop_default: t.List[DropDefault] = [] - self._drop_table: t.Optional[DropTable] = None - self._drop: t.List[DropColumn] = [] - self._rename_columns: t.List[RenameColumn] = [] - self._rename_table: t.List[RenameTable] = [] - self._set_column_type: t.List[SetColumnType] = [] - self._set_default: t.List[SetDefault] = [] - self._set_digits: t.List[SetDigits] = [] - self._set_length: t.List[SetLength] = [] - self._set_null: t.List[SetNull] = [] - self._set_unique: t.List[SetUnique] = [] - - def add_column(self, name: str, column: Column) -> Alter: + self._add_foreign_key_constraint: list[AddForeignKeyConstraint] = [] + self._add: list[AddColumn] = [] + self._drop_constraint: list[DropConstraint] = [] + self._drop_default: list[DropDefault] = [] + self._drop_table: Optional[DropTable] = None + self._drop: list[DropColumn] = [] + self._rename_columns: list[RenameColumn] = [] + self._rename_table: list[RenameTable] = [] + self._set_column_type: list[SetColumnType] = [] + self._set_default: list[SetDefault] = [] + self._set_digits: list[SetDigits] = [] + self._set_length: list[SetLength] = [] + self._set_null: list[SetNull] = [] + self._set_schema: list[SetSchema] = [] + self._set_unique: list[SetUnique] = [] + self._rename_constraint: list[RenameConstraint] = [] + + def add_column(self: Self, name: str, column: Column) -> Self: """ - Band.alter().add_column(‘members’, Integer()) + Add a column to the table:: + + >>> await Band.alter().add_column('members', Integer()) + """ column._meta._table = self.table + column._meta._name = name + column._meta.db_column_name = name + + if isinstance(column, ForeignKey): + column._setup(table_class=self.table) + self._add.append(AddColumn(column, name)) return self - def drop_column(self, column: t.Union[str, Column]) -> Alter: + def drop_column(self, column: Union[str, Column]) -> Alter: """ - Band.alter().drop_column(Band.popularity) + Drop a column from the table:: + + >>> await Band.alter().drop_column(Band.popularity) + """ 
self._drop.append(DropColumn(column)) return self - def drop_default(self, column: t.Union[str, Column]) -> Alter: + def drop_default(self, column: Union[str, Column]) -> Alter: """ - Band.alter().drop_default(Band.popularity) + Drop the default from a column:: + + >>> await Band.alter().drop_default(Band.popularity) + """ self._drop_default.append(DropDefault(column=column)) return self @@ -338,10 +365,13 @@ def drop_table( self, cascade: bool = False, if_exists: bool = False ) -> Alter: """ - Band.alter().drop_table() + Drop the table:: + + >>> await Band.alter().drop_table() + """ self._drop_table = DropTable( - tablename=self.table._meta.tablename, + table=self.table, cascade=cascade, if_exists=if_exists, ) @@ -349,18 +379,45 @@ def drop_table( def rename_table(self, new_name: str) -> Alter: """ - Band.alter().rename_table('musical_group') + Rename the table:: + + >>> await Band.alter().rename_table('musical_group') + """ # We override the existing one rather than appending. self._rename_table = [RenameTable(new_name=new_name)] return self + def rename_constraint(self, old_name: str, new_name: str) -> Alter: + """ + Rename a constraint on the table:: + + >>> await Band.alter().rename_constraint( + ... 'old_constraint_name', + ... 'new_constraint_name', + ... ) + + """ + self._rename_constraint = [ + RenameConstraint( + old_name=old_name, + new_name=new_name, + ) + ] + return self + def rename_column( - self, column: t.Union[str, Column], new_name: str + self, column: Union[str, Column], new_name: str ) -> Alter: """ - Band.alter().rename_column(Band.popularity, ‘rating’) - Band.alter().rename_column('popularity', ‘rating’) + Rename a column on the table:: + + # Specify the column with a `Column` instance: + >>> await Band.alter().rename_column(Band.popularity, 'rating') + + # Or by name: + >>> await Band.alter().rename_column('popularity', 'rating') + """ self._rename_columns.append(RenameColumn(column, new_name)) return self @@ -369,10 +426,19 @@ def set_column_type( self, old_column: Column, new_column: Column, - using_expression: t.Optional[str] = None, + using_expression: Optional[str] = None, ) -> Alter: """ - Change the type of a column. + Change the type of a column:: + + >>> await Band.alter().set_column_type(Band.popularity, BigInt()) + + :param using_expression: + When changing a column's type, the database doesn't always know how + to convert the existing data in that column to the new type. You + can provide a hint to the database on what to do. For example + ``'name::integer'``. + """ self._set_column_type.append( SetColumnType( @@ -383,42 +449,56 @@ def set_column_type( ) return self - def set_default(self, column: Column, value: t.Any) -> Alter: + def set_default(self, column: Column, value: Any) -> Alter: """ - Set the default for a column. 
+ Set the default for a column:: + + >>> await Band.alter().set_default(Band.popularity, 0) - Band.alter().set_default(Band.popularity, 0) """ self._set_default.append(SetDefault(column=column, value=value)) return self def set_null( - self, column: t.Union[str, Column], boolean: bool = True + self, column: Union[str, Column], boolean: bool = True ) -> Alter: """ - Band.alter().set_null(Band.name, True) - Band.alter().set_null('name', True) + Change a column to be nullable or not:: + + # Specify the column using a `Column` instance: + >>> await Band.alter().set_null(Band.name, True) + + # Or using a string: + >>> await Band.alter().set_null('name', True) + """ self._set_null.append(SetNull(column, boolean)) return self def set_unique( - self, column: t.Union[str, Column], boolean: bool = True + self, column: Union[str, Column], boolean: bool = True ) -> Alter: """ - Band.alter().set_unique(Band.name, True) - Band.alter().set_unique('name', True) + Make a column unique or not:: + + # Specify the column using a `Column` instance: + >>> await Band.alter().set_unique(Band.name, True) + + # Or using a string: + >>> await Band.alter().set_unique('name', True) + """ self._set_unique.append(SetUnique(column, boolean)) return self - def set_length(self, column: t.Union[str, Varchar], length: int) -> Alter: + def set_length(self, column: Union[str, Varchar], length: int) -> Alter: """ Change the max length of a varchar column. Unfortunately, this isn't supported by SQLite, but SQLite also doesn't enforce any length limits - on varchar columns anyway. + on varchar columns anyway:: + + >>> await Band.alter().set_length('name', 512) - Band.alter().set_length('name', 512) """ if self.engine_type == "sqlite": colored_warning( @@ -439,63 +519,84 @@ def set_length(self, column: t.Union[str, Varchar], length: int) -> Alter: self._set_length.append(SetLength(column, length)) return self - def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str: + def _get_constraint_name(self, column: Union[str, ForeignKey]) -> str: column_name = AlterColumnStatement(column=column).column_name tablename = self.table._meta.tablename - constraint_name = f"{tablename}_{column_name}_fk" - return constraint_name + return f"{tablename}_{column_name}_fkey" def drop_constraint(self, constraint_name: str) -> Alter: - self._drop_contraint.append( + self._drop_constraint.append( DropConstraint(constraint_name=constraint_name) ) return self def drop_foreign_key_constraint( - self, column: t.Union[str, ForeignKey] + self, column: Union[str, ForeignKey] ) -> Alter: constraint_name = self._get_constraint_name(column=column) - return self.drop_constraint(constraint_name=constraint_name) + self._drop_constraint.append( + DropConstraint(constraint_name=constraint_name) + ) + return self def add_foreign_key_constraint( self, - column: t.Union[str, ForeignKey], - referenced_table_name: str, - on_delete: t.Optional[OnDelete] = None, - on_update: t.Optional[OnUpdate] = None, - referenced_column_name: str = "id", + column: Union[str, ForeignKey], + referenced_table_name: Optional[str] = None, + referenced_column_name: Optional[str] = None, + constraint_name: Optional[str] = None, + on_delete: Optional[OnDelete] = None, + on_update: Optional[OnUpdate] = None, ) -> Alter: """ - This will add a new foreign key constraint. + Add a new foreign key constraint:: + + >>> await Band.alter().add_foreign_key_constraint( + ... Band.manager, + ... on_delete=OnDelete.cascade + ... 
) - Band.alter().add_foreign_key_constraint( - Band.manager, - referenced_table_name='manager', - on_delete=OnDelete.cascade - ) """ - constraint_name = self._get_constraint_name(column=column) + constraint_name = constraint_name or self._get_constraint_name( + column=column + ) column_name = AlterColumnStatement(column=column).column_name + if referenced_column_name is None: + if isinstance(column, ForeignKey): + referenced_column_name = ( + column._foreign_key_meta.resolved_target_column._meta.db_column_name # noqa: E501 + ) + else: + raise ValueError("Please pass in `referenced_column_name`.") + + if referenced_table_name is None: + if isinstance(column, ForeignKey): + referenced_table_name = ( + column._foreign_key_meta.resolved_references._meta.tablename # noqa: E501 + ) + else: + raise ValueError("Please pass in `referenced_table_name`.") + self._add_foreign_key_constraint.append( AddForeignKeyConstraint( constraint_name=constraint_name, foreign_key_column_name=column_name, referenced_table_name=referenced_table_name, + referenced_column_name=referenced_column_name, on_delete=on_delete, on_update=on_update, - referenced_column_name=referenced_column_name, ) ) return self def set_digits( self, - column: t.Union[str, Numeric], - digits: t.Optional[t.Tuple[int, int]], + column: Union[str, Numeric], + digits: Optional[tuple[int, int]], ) -> Alter: """ - Alter the precision and scale for a Numeric column. + Alter the precision and scale for a ``Numeric`` column. """ column_type = ( column.__class__.__name__.upper() @@ -511,20 +612,34 @@ def set_digits( ) return self + def set_schema(self, schema_name: str) -> Alter: + """ + Move the table to a different schema. + + :param schema_name: + The schema to move the table to. + + """ + self._set_schema.append(SetSchema(schema_name=schema_name)) + return self + @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def default_ddl(self) -> Sequence[str]: if self._drop_table is not None: - return [self._drop_table.querystring] + return [self._drop_table.ddl] - query = f"ALTER TABLE {self.table._meta.tablename}" + query = f"ALTER TABLE {self.table._meta.get_formatted_tablename()}" alterations = [ - i.querystring + i.ddl for i in itertools.chain( self._add, + self._add_foreign_key_constraint, self._rename_columns, self._rename_table, + self._rename_constraint, self._drop, + self._drop_constraint, self._drop_default, self._set_column_type, self._set_unique, @@ -532,15 +647,18 @@ def default_querystrings(self) -> t.Sequence[QueryString]: self._set_length, self._set_default, self._set_digits, + self._set_schema, ) ] if self.engine_type == "sqlite": # Can only perform one alter statement at a time. 
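+            # (SQLite's ALTER TABLE grammar only accepts a single action
+            # per statement, whereas Postgres can chain several
+            # comma-separated actions in one statement.)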
- query += " {}" - return [QueryString(query, i) for i in alterations] + return [f"{query} {i}" for i in alterations] # Postgres can perform them all at once: - query += ",".join([" {}" for i in alterations]) + query += ",".join(f" {i}" for i in alterations) + + return [query] + - return [QueryString(query, *alterations)] +Self = TypeVar("Self", bound=Alter) diff --git a/piccolo/query/methods/count.py b/piccolo/query/methods/count.py index 100bf43a0..3c61ae9b6 100644 --- a/piccolo/query/methods/count.py +++ b/piccolo/query/methods/count.py @@ -1,39 +1,62 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, TypeVar, Union from piccolo.custom_types import Combinable from piccolo.query.base import Query +from piccolo.query.functions.aggregate import Count as CountFunction from piccolo.query.mixins import WhereDelegate from piccolo.querystring import QueryString -from .select import Select - -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover + from piccolo.columns import Column from piccolo.table import Table class Count(Query): - __slots__ = ("where_delegate",) - def __init__(self, table: t.Type[Table], **kwargs): + __slots__ = ("where_delegate", "column", "_distinct") + + def __init__( + self, + table: type[Table], + column: Optional[Column] = None, + distinct: Optional[Sequence[Column]] = None, + **kwargs, + ): super().__init__(table, **kwargs) + self.column = column + self._distinct = distinct self.where_delegate = WhereDelegate() - def where(self, where: Combinable) -> Count: - self.where_delegate.where(where) + ########################################################################### + # Clauses + + def where(self: Self, *where: Union[Combinable, QueryString]) -> Self: + self.where_delegate.where(*where) return self + def distinct(self: Self, columns: Optional[Sequence[Column]]) -> Self: + self._distinct = columns + return self + + ########################################################################### + async def response_handler(self, response) -> bool: return response[0]["count"] @property - def default_querystrings(self) -> t.Sequence[QueryString]: - select = Select(self.table) - select.where_delegate._where = self.where_delegate._where - return [ - QueryString( - 'SELECT COUNT(*) AS "count" FROM ({}) AS "subquery"', - select.querystrings[0], - ) - ] + def default_querystrings(self) -> Sequence[QueryString]: + table: type[Table] = self.table + + query = table.select( + CountFunction(column=self.column, distinct=self._distinct) + ) + + query.where_delegate._where = self.where_delegate._where + + return query.querystrings + + +Self = TypeVar("Self", bound=Count) diff --git a/piccolo/query/methods/create.py b/piccolo/query/methods/create.py index 1b0108bda..592cc9f05 100644 --- a/piccolo/query/methods/create.py +++ b/piccolo/query/methods/create.py @@ -1,35 +1,69 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence +from typing import TYPE_CHECKING -from piccolo.query.base import Query +from piccolo.query.base import DDL from piccolo.query.methods.create_index import CreateIndex -from piccolo.querystring import QueryString -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.table import Table -class Create(Query): +class Create(DDL): """ Creates a database table. 
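+
+    For example (a sketch, assuming a ``Band`` table as in the other
+    docstring examples; ``create_table`` is the usual entry point for
+    this query)::
+
+        >>> await Band.create_table(if_not_exists=True)
+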
""" - __slots__ = ("if_not_exists", "only_default_columns") + __slots__ = ("if_not_exists", "only_default_columns", "auto_create_schema") def __init__( self, - table: t.Type[Table], + table: type[Table], if_not_exists: bool = False, only_default_columns: bool = False, + auto_create_schema: bool = True, **kwargs, ): + """ + :param table: + The table to create. + :param if_not_exists: + If ``True``, no error will be raised if this table already exists. + :param only_default_columns: + If ``True``, just the basic table and default primary key are + created, rather than all columns. Not typically needed. + :param auto_create_schema: + If the table belongs to a database schema, then make sure the + schema exists before creating the table. + + """ super().__init__(table, **kwargs) self.if_not_exists = if_not_exists self.only_default_columns = only_default_columns + self.auto_create_schema = auto_create_schema @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def default_ddl(self) -> Sequence[str]: + ddl: list[str] = [] + + schema_name = self.table._meta.schema + if ( + self.auto_create_schema + and schema_name is not None + and schema_name != "public" + and self.engine_type != "sqlite" + ): + from piccolo.schema import CreateSchema + + ddl.append( + CreateSchema( + schema_name=schema_name, + if_not_exists=True, + db=self.table._meta.db, + ).ddl + ) + prefix = "CREATE TABLE" if self.if_not_exists: prefix += " IF NOT EXISTS" @@ -39,21 +73,19 @@ def default_querystrings(self) -> t.Sequence[QueryString]: else: columns = self.table._meta.columns - base = f"{prefix} {self.table._meta.tablename}" - columns_sql = ", ".join(["{}" for i in columns]) - query = f"{base} ({columns_sql})" - create_table = QueryString(query, *[i.querystring for i in columns]) + base = f"{prefix} {self.table._meta.get_formatted_tablename()}" + columns_sql = ", ".join(i.ddl for i in columns) + ddl.append(f"{base} ({columns_sql})") - create_indexes: t.List[QueryString] = [] for column in columns: if column._meta.index is True: - create_indexes.extend( + ddl.extend( CreateIndex( table=self.table, columns=[column], method=column._meta.index_method, if_not_exists=self.if_not_exists, - ).querystrings + ).ddl ) - return [create_table] + create_indexes + return ddl diff --git a/piccolo/query/methods/create_index.py b/piccolo/query/methods/create_index.py index 3fc4685aa..64ae4b4d8 100644 --- a/piccolo/query/methods/create_index.py +++ b/piccolo/query/methods/create_index.py @@ -1,21 +1,21 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence +from typing import TYPE_CHECKING, Union from piccolo.columns import Column from piccolo.columns.indexes import IndexMethod -from piccolo.query.base import Query -from piccolo.querystring import QueryString +from piccolo.query.base import DDL -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.table import Table -class CreateIndex(Query): +class CreateIndex(DDL): def __init__( self, - table: t.Type[Table], - columns: t.List[t.Union[Column, str]], + table: type[Table], + columns: Union[list[Column], list[str]], method: IndexMethod = IndexMethod.btree, if_not_exists: bool = False, **kwargs, @@ -26,9 +26,10 @@ def __init__( super().__init__(table, **kwargs) @property - def column_names(self) -> t.List[str]: + def column_names(self) -> list[str]: return [ - i._meta.name if isinstance(i, Column) else i for i in self.columns + i._meta.db_column_name if isinstance(i, Column) else i + for i in 
self.columns
         ]

     @property
@@ -39,32 +40,36 @@ def prefix(self) -> str:
         return prefix

     @property
-    def postgres_querystrings(self) -> t.Sequence[QueryString]:
+    def postgres_ddl(self) -> Sequence[str]:
         column_names = self.column_names
         index_name = self.table._get_index_name(column_names)
-        tablename = self.table._meta.tablename
+        tablename = self.table._meta.get_formatted_tablename()
         method_name = self.method.value
-        column_names_str = ", ".join(column_names)
+        column_names_str = ", ".join([f'"{i}"' for i in self.column_names])
         return [
-            QueryString(
+            (
                 f"{self.prefix} {index_name} ON {tablename} USING "
                 f"{method_name} ({column_names_str})"
             )
         ]

     @property
-    def sqlite_querystrings(self) -> t.Sequence[QueryString]:
+    def cockroach_ddl(self) -> Sequence[str]:
+        return self.postgres_ddl
+
+    @property
+    def sqlite_ddl(self) -> Sequence[str]:
         column_names = self.column_names
         index_name = self.table._get_index_name(column_names)
-        tablename = self.table._meta.tablename
+        tablename = self.table._meta.get_formatted_tablename()
         method_name = self.method.value
         if method_name != "btree":
             raise ValueError("SQLite only supports btree indexes.")
-        column_names_str = ", ".join(column_names)
+        column_names_str = ", ".join([f'"{i}"' for i in self.column_names])
         return [
-            QueryString(
+            (
                 f"{self.prefix} {index_name} ON {tablename} "
                 f"({column_names_str})"
             )
diff --git a/piccolo/query/methods/delete.py b/piccolo/query/methods/delete.py
index be95a0c01..ea7d56a7e 100644
--- a/piccolo/query/methods/delete.py
+++ b/piccolo/query/methods/delete.py
@@ -1,13 +1,15 @@
 from __future__ import annotations

-import typing as t
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, TypeVar, Union

 from piccolo.custom_types import Combinable
 from piccolo.query.base import Query
-from piccolo.query.mixins import WhereDelegate
+from piccolo.query.mixins import ReturningDelegate, WhereDelegate
 from piccolo.querystring import QueryString

-if t.TYPE_CHECKING:  # pragma: no cover
+if TYPE_CHECKING:  # pragma: no cover
+    from piccolo.columns import Column
     from piccolo.table import Table

@@ -17,15 +19,24 @@ class DeletionError(Exception):

 class Delete(Query):
-    __slots__ = ("force", "where_delegate")
+    __slots__ = (
+        "force",
+        "returning_delegate",
+        "where_delegate",
+    )

-    def __init__(self, table: t.Type[Table], force: bool = False, **kwargs):
+    def __init__(self, table: type[Table], force: bool = False, **kwargs):
         super().__init__(table, **kwargs)
         self.force = force
+        self.returning_delegate = ReturningDelegate()
         self.where_delegate = WhereDelegate()

-    def where(self, where: Combinable) -> Delete:
-        self.where_delegate.where(where)
+    def where(self: Self, *where: Union[Combinable, QueryString]) -> Self:
+        self.where_delegate.where(*where)
+        return self
+
+    def returning(self: Self, *columns: Column) -> Self:
+        self.returning_delegate.returning(columns)
         return self

     def _validate(self):
@@ -37,14 +48,31 @@ def _validate(self):
             classname = self.table.__name__
             raise DeletionError(
                 "Do you really want to delete all the data from "
-                f"{classname}? If so, use {classname}.delete(force=True)."
+                f"{classname}? If so, use {classname}.delete(force=True). "
+                "Otherwise, add a where clause."
) @property - def default_querystrings(self) -> t.Sequence[QueryString]: - query = f"DELETE FROM {self.table._meta.tablename}" + def default_querystrings(self) -> Sequence[QueryString]: + query = f"DELETE FROM {self.table._meta.get_formatted_tablename()}" + + querystring = QueryString(query) + if self.where_delegate._where: - query += " WHERE {}" - return [QueryString(query, self.where_delegate._where.querystring)] - else: - return [QueryString(query)] + querystring = QueryString( + "{} WHERE {}", + querystring, + self.where_delegate._where.querystring_for_update_and_delete, + ) + + if self.returning_delegate._returning: + querystring = QueryString( + "{}{}", + querystring, + self.returning_delegate._returning.querystring, + ) + + return [querystring] + + +Self = TypeVar("Self", bound=Delete) diff --git a/piccolo/query/methods/drop_index.py b/piccolo/query/methods/drop_index.py index 437728437..1b2d9f082 100644 --- a/piccolo/query/methods/drop_index.py +++ b/piccolo/query/methods/drop_index.py @@ -1,20 +1,21 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence +from typing import TYPE_CHECKING, Union from piccolo.columns.base import Column from piccolo.query.base import Query from piccolo.querystring import QueryString -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.table import Table class DropIndex(Query): def __init__( self, - table: t.Type[Table], - columns: t.List[t.Union[Column, str]], + table: type[Table], + columns: Union[list[Column], list[str]], if_exists: bool = True, **kwargs, ): @@ -23,13 +24,13 @@ def __init__( super().__init__(table, **kwargs) @property - def column_names(self) -> t.List[str]: + def column_names(self) -> list[str]: return [ i._meta.name if isinstance(i, Column) else i for i in self.columns ] @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def default_querystrings(self) -> Sequence[QueryString]: column_names = self.column_names index_name = self.table._get_index_name(column_names) query = "DROP INDEX" diff --git a/piccolo/query/methods/exists.py b/piccolo/query/methods/exists.py index 09fc40071..d6a346ac9 100644 --- a/piccolo/query/methods/exists.py +++ b/piccolo/query/methods/exists.py @@ -1,28 +1,24 @@ from __future__ import annotations -import typing as t -from dataclasses import dataclass +from collections.abc import Sequence +from typing import TypeVar, Union -from piccolo.custom_types import Combinable +from piccolo.custom_types import Combinable, TableInstance from piccolo.query.base import Query from piccolo.query.methods.select import Select from piccolo.query.mixins import WhereDelegate from piccolo.querystring import QueryString -if t.TYPE_CHECKING: # pragma: no cover - from piccolo.table import Table - -@dataclass -class Exists(Query): +class Exists(Query[TableInstance, bool]): __slots__ = ("where_delegate",) - def __init__(self, table: t.Type[Table], **kwargs): + def __init__(self, table: type[TableInstance], **kwargs): super().__init__(table, **kwargs) self.where_delegate = WhereDelegate() - def where(self, where: Combinable) -> Exists: - self.where_delegate.where(where) + def where(self: Self, *where: Union[Combinable, QueryString]) -> Self: + self.where_delegate.where(*where) return self async def response_handler(self, response) -> bool: @@ -30,7 +26,7 @@ async def response_handler(self, response) -> bool: return bool(response[0]["exists"]) @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def 
default_querystrings(self) -> Sequence[QueryString]: select = Select(table=self.table) select.where_delegate._where = self.where_delegate._where return [ @@ -38,3 +34,6 @@ def default_querystrings(self) -> t.Sequence[QueryString]: 'SELECT EXISTS({}) AS "exists"', select.querystrings[0] ) ] + + +Self = TypeVar("Self", bound=Exists) diff --git a/piccolo/query/methods/indexes.py b/piccolo/query/methods/indexes.py index 6ab4dbeae..c5c8b8be7 100644 --- a/piccolo/query/methods/indexes.py +++ b/piccolo/query/methods/indexes.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence from piccolo.query.base import Query from piccolo.querystring import QueryString @@ -12,17 +12,21 @@ class Indexes(Query): """ @property - def postgres_querystrings(self) -> t.Sequence[QueryString]: + def postgres_querystrings(self) -> Sequence[QueryString]: return [ QueryString( "SELECT indexname AS name FROM pg_indexes " "WHERE tablename = {}", - self.table._meta.tablename, + self.table._meta.get_formatted_tablename(quoted=False), ) ] @property - def sqlite_querystrings(self) -> t.Sequence[QueryString]: + def cockroach_querystrings(self) -> Sequence[QueryString]: + return self.postgres_querystrings + + @property + def sqlite_querystrings(self) -> Sequence[QueryString]: tablename = self.table._meta.tablename return [QueryString(f"PRAGMA index_list({tablename})")] diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 63d6e0fe5..f9bce9516 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -1,30 +1,94 @@ from __future__ import annotations -import typing as t -from dataclasses import dataclass +from collections.abc import Sequence +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Optional, + TypeVar, + Union, +) +from piccolo.custom_types import Combinable, TableInstance from piccolo.query.base import Query -from piccolo.query.mixins import AddDelegate +from piccolo.query.mixins import ( + AddDelegate, + OnConflictAction, + OnConflictDelegate, + ReturningDelegate, +) from piccolo.querystring import QueryString -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover + from piccolo.columns.base import Column from piccolo.table import Table -@dataclass -class Insert(Query): - __slots__ = ("add_delegate",) +class Insert( + Generic[TableInstance], Query[TableInstance, list[dict[str, Any]]] +): + __slots__ = ("add_delegate", "on_conflict_delegate", "returning_delegate") - def __init__(self, table: t.Type[Table], *instances: Table, **kwargs): + def __init__( + self, table: type[TableInstance], *instances: TableInstance, **kwargs + ): super().__init__(table, **kwargs) self.add_delegate = AddDelegate() + self.returning_delegate = ReturningDelegate() + self.on_conflict_delegate = OnConflictDelegate() self.add(*instances) - def add(self, *instances: Table) -> Insert: + ########################################################################### + # Clauses + + def add(self: Self, *instances: Table) -> Self: self.add_delegate.add(*instances, table_class=self.table) return self - def run_callback(self, results): + def returning(self: Self, *columns: Column) -> Self: + self.returning_delegate.returning(columns) + return self + + def on_conflict( + self: Self, + target: Optional[Union[str, Column, tuple[Column, ...]]] = None, + action: Union[ + OnConflictAction, Literal["DO NOTHING", "DO UPDATE"] + ] = OnConflictAction.do_nothing, + values: Optional[Sequence[Union[Column, 
tuple[Column, Any]]]] = None, + where: Optional[Combinable] = None, + ) -> Self: + if ( + self.engine_type == "sqlite" + and self.table._meta.db.get_version_sync() < 3.24 + ): + raise NotImplementedError( + "SQLite versions lower than 3.24 don't support ON CONFLICT" + ) + + if ( + self.engine_type in ("postgres", "cockroach") + and len(self.on_conflict_delegate._on_conflict.on_conflict_items) + == 1 + ): + raise NotImplementedError( + "Postgres and Cockroach only support a single ON CONFLICT " + "clause." + ) + + self.on_conflict_delegate.on_conflict( + target=target, + action=action, + values=values, + where=where, + ) + return self + + ########################################################################### + + def _raw_response_callback(self, results: list): """ Assign the ids of the created rows to the model instances. """ @@ -33,40 +97,56 @@ def run_callback(self, results): setattr( table_instance, self.table._meta.primary_key._meta.name, - row[self.table._meta.primary_key._meta.name], + row.get( + self.table._meta.primary_key._meta.db_column_name, None + ), ) table_instance._exists_in_db = True @property - def sqlite_querystrings(self) -> t.Sequence[QueryString]: - base = f"INSERT INTO {self.table._meta.tablename}" - columns = ",".join([i._meta.name for i in self.table._meta.columns]) - values = ",".join(["{}" for _ in self.add_delegate._add]) - query = f"{base} ({columns}) VALUES {values}" - return [ - QueryString( - query, - *[i.querystring for i in self.add_delegate._add], - query_type="insert", - table=self.table, - ) - ] - - @property - def postgres_querystrings(self) -> t.Sequence[QueryString]: - base = f"INSERT INTO {self.table._meta.tablename}" + def default_querystrings(self) -> Sequence[QueryString]: + base = f"INSERT INTO {self.table._meta.get_formatted_tablename()}" columns = ",".join( - [f'"{i._meta.name}"' for i in self.table._meta.columns] + f'"{i._meta.db_column_name}"' for i in self.table._meta.columns ) - values = ",".join(["{}" for i in self.add_delegate._add]) - primary_key_name = self.table._meta.primary_key._meta.name - query = ( - f"{base} ({columns}) VALUES {values} RETURNING {primary_key_name}" + values = ",".join("{}" for _ in self.add_delegate._add) + query = f"{base} ({columns}) VALUES {values}" + querystring = QueryString( + query, + *[i.querystring for i in self.add_delegate._add], + query_type="insert", + table=self.table, ) - return [ - QueryString( - query, - *[i.querystring for i in self.add_delegate._add], + + engine_type = self.engine_type + + on_conflict = self.on_conflict_delegate._on_conflict + if on_conflict.on_conflict_items: + querystring = QueryString( + "{}{}", + querystring, + on_conflict.querystring, query_type="insert", + table=self.table, ) - ] + + if engine_type in ("postgres", "cockroach") or ( + engine_type == "sqlite" + and self.table._meta.db.get_version_sync() >= 3.35 + ): + returning = self.returning_delegate._returning + if returning: + return [ + QueryString( + "{}{}", + querystring, + returning.querystring, + query_type="insert", + table=self.table, + ) + ] + + return [querystring] + + +Self = TypeVar("Self", bound=Insert) diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index aa7a3692e..1bc0d711f 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -1,104 +1,519 @@ from __future__ import annotations -import typing as t -from dataclasses import dataclass +from collections.abc import Callable, Generator, Sequence +from typing import ( + TYPE_CHECKING, + Any, 
+ Generic, + Literal, + Optional, + TypeVar, + Union, + cast, +) -from piccolo.custom_types import Combinable -from piccolo.engine.base import Batch +from piccolo.columns.column_types import ForeignKey, ReferencedTable +from piccolo.columns.combination import And, Where +from piccolo.custom_types import Combinable, TableInstance +from piccolo.engine.base import BaseBatch from piccolo.query.base import Query +from piccolo.query.methods.select import Select from piccolo.query.mixins import ( + AsOfDelegate, + CallbackDelegate, + CallbackType, LimitDelegate, + LockRowsDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, + OrderByRaw, OutputDelegate, + PrefetchDelegate, WhereDelegate, ) +from piccolo.query.proxy import Proxy from piccolo.querystring import QueryString +from piccolo.utils.dictionary import make_nested +from piccolo.utils.sync import run_sync -from .select import Select - -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column from piccolo.table import Table -@dataclass -class Objects(Query): +############################################################################### + + +class GetOrCreate( + Proxy["Objects[TableInstance]", TableInstance], Generic[TableInstance] +): + def __init__( + self, + query: Objects[TableInstance], + table_class: type[TableInstance], + where: Combinable, + defaults: dict[Column, Any], + ): + self.query = query + self.table_class = table_class + self.where = where + self.defaults = defaults + + async def run( + self, node: Optional[str] = None, in_pool: bool = True + ) -> TableInstance: + """ + :raises ValueError: + If more than one matching row is found. + + """ + instance = await self.query.get(self.where).run( + node=node, in_pool=in_pool + ) + if instance: + instance._was_created = False + return instance + + data = {**self.defaults} + + # If it's a complex `where`, there can be several column values to + # extract e.g. (Band.name == 'Pythonistas') & (Band.popularity == 1000) + if isinstance(self.where, Where): + data[self.where.column] = self.where.value + elif isinstance(self.where, And): + for column, value in self.where.get_column_values().items(): + if len(column._meta.call_chain) == 0: + # Make sure we only set the value if the column belongs + # to this table. + data[column] = value + + instance = self.table_class(_data=data) + + await instance.save().run(node=node, in_pool=in_pool) + + # If the user wants us to prefetch related objects, for example: + # + # await Band.objects(Band.manager).get_or_create( + # (Band.name == 'Pythonistas') & (Band.manager == 1) + # ) + # + # Then we need to fetch the related objects. 
+        # See https://github.com/piccolo-orm/piccolo/issues/597
+        prefetch = self.query.prefetch_delegate.fk_columns
+        if prefetch:
+            table = instance.__class__
+            primary_key = table._meta.primary_key
+            instance = (
+                await table.objects(*prefetch)
+                .get(primary_key == getattr(instance, primary_key._meta.name))
+                .run()
+            )
+
+        instance = cast(TableInstance, instance)
+        instance._was_created = True
+        return instance
+
+
+class Get(
+    Proxy["First[TableInstance]", Optional[TableInstance]],
+    Generic[TableInstance],
+):
+    pass
+
+
+class First(
+    Proxy["Objects[TableInstance]", Optional[TableInstance]],
+    Generic[TableInstance],
+):
+    async def run(
+        self, node: Optional[str] = None, in_pool: bool = True
+    ) -> Optional[TableInstance]:
+        objects = await self.query.run(
+            node=node, in_pool=in_pool, use_callbacks=False
+        )
+
+        results = objects[0] if objects else None
+
+        modified_response: Optional[TableInstance] = (
+            await self.query.callback_delegate.invoke(
+                results=results, kind=CallbackType.success
+            )
+        )
+        return modified_response
+
+
+class Create(Generic[TableInstance]):
+    """
+    This is provided as a simple convenience. Rather than running::
+
+        band = Band(name='Pythonistas')
+        await band.save()
+
+    We can instead do it in a single line::
+
+        band = await Band.objects().create(name='Pythonistas')
+
+    """
+
+    def __init__(
+        self,
+        table_class: type[TableInstance],
+        columns: dict[str, Any],
+    ):
+        self.table_class = table_class
+        self.columns = columns
+
+    async def run(
+        self,
+        node: Optional[str] = None,
+        in_pool: bool = True,
+    ) -> TableInstance:
+        instance = self.table_class(**self.columns)
+        await instance.save().run(node=node, in_pool=in_pool)
+        return instance
+
+    def __await__(self) -> Generator[None, None, TableInstance]:
+        """
+        If the user doesn't explicitly call .run(), proxy to it as a
+        convenience.
+        """
+        return self.run().__await__()
+
+    def run_sync(self, *args, **kwargs) -> TableInstance:
+        return run_sync(self.run(*args, **kwargs))
+
+
+class UpdateSelf:
+
+    def __init__(
+        self,
+        row: Table,
+        values: dict[Union[Column, str], Any],
+    ):
+        self.row = row
+        self.values = values
+
+    async def run(
+        self,
+        node: Optional[str] = None,
+        in_pool: bool = True,
+    ) -> None:
+        if not self.row._exists_in_db:
+            raise ValueError("This row doesn't exist in the database.")
+
+        TableClass = self.row.__class__
+
+        primary_key = TableClass._meta.primary_key
+        primary_key_value = getattr(self.row, primary_key._meta.name)
+
+        if primary_key_value is None:
+            raise ValueError("The primary key is None")
+
+        columns = [
+            TableClass._meta.get_column_by_name(i) if isinstance(i, str) else i
+            for i in self.values.keys()
+        ]
+
+        response = (
+            await TableClass.update(self.values)
+            .where(primary_key == primary_key_value)
+            .returning(*columns)
+            .run(
+                node=node,
+                in_pool=in_pool,
+            )
+        )
+
+        for key, value in response[0].items():
+            setattr(self.row, key, value)
+
+    def __await__(self) -> Generator[None, None, None]:
+        """
+        If the user doesn't explicitly call .run(), proxy to it as a
+        convenience.
+ """ + return self.run().__await__() + + def run_sync(self, *args, **kwargs) -> None: + return run_sync(self.run(*args, **kwargs)) + + +class GetRelated(Generic[ReferencedTable]): + + def __init__(self, row: Table, foreign_key: ForeignKey[ReferencedTable]): + self.row = row + self.foreign_key = foreign_key + + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + ) -> Optional[ReferencedTable]: + if not self.row._exists_in_db: + raise ValueError("The object doesn't exist in the database.") + + root_table = self.row.__class__ + + data = ( + await root_table.select( + *[ + i.as_alias(i._meta.name) + for i in self.foreign_key.all_columns() + ] + ) + .where( + root_table._meta.primary_key + == getattr(self.row, root_table._meta.primary_key._meta.name) + ) + .first() + .run(node=node, in_pool=in_pool) + ) + + # Make sure that some values were returned: + if data is None or not any(data.values()): + return None + + references = cast( + type[ReferencedTable], + self.foreign_key._foreign_key_meta.resolved_references, + ) + + referenced_object = references(**data) + referenced_object._exists_in_db = True + return referenced_object + + def __await__( + self, + ) -> Generator[None, None, Optional[ReferencedTable]]: + """ + If the user doesn't explicity call .run(), proxy to it as a + convenience. + """ + return self.run().__await__() + + def run_sync(self, *args, **kwargs) -> Optional[ReferencedTable]: + return run_sync(self.run(*args, **kwargs)) + + +############################################################################### + + +class Objects( + Query[TableInstance, list[TableInstance]], Generic[TableInstance] +): """ Almost identical to select, except you have to select all fields, and table instances are returned, rather than just data. 
""" __slots__ = ( + "nested", + "as_of_delegate", "limit_delegate", "offset_delegate", "order_by_delegate", "output_delegate", + "callback_delegate", + "prefetch_delegate", "where_delegate", + "lock_rows_delegate", ) - def __init__(self, table: t.Type[Table], **kwargs): + def __init__( + self, + table: type[TableInstance], + prefetch: Sequence[Union[ForeignKey, list[ForeignKey]]] = (), + **kwargs, + ): super().__init__(table, **kwargs) + self.as_of_delegate = AsOfDelegate() self.limit_delegate = LimitDelegate() self.offset_delegate = OffsetDelegate() self.order_by_delegate = OrderByDelegate() self.output_delegate = OutputDelegate() self.output_delegate._output.as_objects = True + self.callback_delegate = CallbackDelegate() + self.prefetch_delegate = PrefetchDelegate() + self.prefetch(*prefetch) self.where_delegate = WhereDelegate() + self.lock_rows_delegate = LockRowsDelegate() - def output(self, load_json: bool = False) -> Objects: + def output(self: Self, load_json: bool = False) -> Self: self.output_delegate.output( as_list=False, as_json=False, load_json=load_json ) return self - def limit(self, number: int) -> Objects: + def callback( + self: Self, + callbacks: Union[Callable, list[Callable]], + *, + on: CallbackType = CallbackType.success, + ) -> Self: + self.callback_delegate.callback(callbacks, on=on) + return self + + def as_of(self, interval: str = "-1s") -> Objects: + if self.engine_type != "cockroach": + raise NotImplementedError("Only CockroachDB supports AS OF") + self.as_of_delegate.as_of(interval) + return self + + def limit(self: Self, number: int) -> Self: self.limit_delegate.limit(number) return self - def first(self) -> Objects: - self.limit_delegate.first() + def prefetch( + self: Self, *fk_columns: Union[ForeignKey, list[ForeignKey]] + ) -> Self: + self.prefetch_delegate.prefetch(*fk_columns) return self - def offset(self, number: int) -> Objects: + def offset(self: Self, number: int) -> Self: self.offset_delegate.offset(number) return self - def order_by(self, *columns: Column, ascending=True) -> Objects: - self.order_by_delegate.order_by(*columns, ascending=ascending) + def order_by( + self: Self, *columns: Union[Column, str, OrderByRaw], ascending=True + ) -> Self: + _columns: list[Union[Column, OrderByRaw]] = [] + for column in columns: + if isinstance(column, str): + _columns.append(self.table._meta.get_column_by_name(column)) + else: + _columns.append(column) + + self.order_by_delegate.order_by(*_columns, ascending=ascending) + return self + + def where(self: Self, *where: Union[Combinable, QueryString]) -> Self: + self.where_delegate.where(*where) return self - def where(self, where: Combinable) -> Objects: - self.where_delegate.where(where) + ########################################################################### + + def first(self) -> First[TableInstance]: + self.limit_delegate.limit(1) + return First[TableInstance](query=self) + + def lock_rows( + self: Self, + lock_strength: Union[ + LockStrength, + Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: tuple[type[Table], ...] 
= (), + ) -> Self: + self.lock_rows_delegate.lock_rows( + lock_strength, nowait, skip_locked, of + ) return self + def get(self, where: Combinable) -> Get[TableInstance]: + self.where_delegate.where(where) + self.limit_delegate.limit(1) + return Get[TableInstance](query=First[TableInstance](query=self)) + + def get_or_create( + self, + where: Combinable, + defaults: Optional[dict[Column, Any]] = None, + ) -> GetOrCreate[TableInstance]: + if defaults is None: + defaults = {} + return GetOrCreate[TableInstance]( + query=self, table_class=self.table, where=where, defaults=defaults + ) + + def create(self, **columns: Any) -> Create[TableInstance]: + return Create[TableInstance](table_class=self.table, columns=columns) + + ########################################################################### + async def batch( - self, batch_size: t.Optional[int] = None, **kwargs - ) -> Batch: + self, + batch_size: Optional[int] = None, + node: Optional[str] = None, + **kwargs, + ) -> BaseBatch: if batch_size: kwargs.update(batch_size=batch_size) + if node: + kwargs.update(node=node) return await self.table._meta.db.batch(self, **kwargs) async def response_handler(self, response): - if self.limit_delegate._first: - if len(response) == 0: - return None - else: - return response[0] + if self.output_delegate._output.nested: + return [make_nested(i) for i in response] else: return response @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def default_querystrings(self) -> Sequence[QueryString]: select = Select(table=self.table) for attr in ( + "as_of_delegate", "limit_delegate", "where_delegate", "offset_delegate", "output_delegate", "order_by_delegate", + "lock_rows_delegate", ): setattr(select, attr, getattr(self, attr)) + if self.prefetch_delegate.fk_columns: + select.columns(*self.table.all_columns()) + for fk in self.prefetch_delegate.fk_columns: + if isinstance(fk, ForeignKey): + select.columns(*fk.all_columns()) + else: + raise ValueError(f"{fk} doesn't seem to be a ForeignKey.") + + # Make sure that all intermediate objects are fully loaded. + for parent_fk in fk._meta.call_chain: + select.columns(*parent_fk.all_columns()) + + select.output_delegate.output(nested=True) + return select.querystrings + + ########################################################################### + + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + use_callbacks: bool = True, + ) -> list[TableInstance]: + results = await super().run(node=node, in_pool=in_pool) + + if use_callbacks: + # With callbacks, the user can return any data that they want. + # Assume that most of the time they will still return a list of + # Table instances. 
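+            #
+            # A sketch of a typical callback (`uppercase_names` is a
+            # hypothetical example, and assumes a `Band` table):
+            #
+            #     async def uppercase_names(bands):
+            #         for band in bands:
+            #             band.name = band.name.upper()
+            #         return bands
+            #
+            #     await Band.objects().callback(uppercase_names)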
+            modified: list[TableInstance] = (
+                await self.callback_delegate.invoke(
+                    results, kind=CallbackType.success
+                )
+            )
+            return modified
+        else:
+            return results
+
+    def __await__(
+        self,
+    ) -> Generator[None, None, list[TableInstance]]:
+        return super().__await__()
+
+
+Self = TypeVar("Self", bound=Objects)
diff --git a/piccolo/query/methods/raw.py b/piccolo/query/methods/raw.py
index 2f6176b18..9c35ba53d 100644
--- a/piccolo/query/methods/raw.py
+++ b/piccolo/query/methods/raw.py
@@ -1,28 +1,40 @@
 from __future__ import annotations

-import typing as t
-from dataclasses import dataclass
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional

+from piccolo.engine.base import BaseBatch
 from piccolo.query.base import Query
 from piccolo.querystring import QueryString

-if t.TYPE_CHECKING:  # pragma: no cover
+if TYPE_CHECKING:  # pragma: no cover
     from piccolo.table import Table

-@dataclass
 class Raw(Query):
     __slots__ = ("querystring",)

     def __init__(
         self,
-        table: t.Type[Table],
+        table: type[Table],
         querystring: QueryString = QueryString(""),
         **kwargs,
     ):
         super().__init__(table, **kwargs)
         self.querystring = querystring

+    async def batch(
+        self,
+        batch_size: Optional[int] = None,
+        node: Optional[str] = None,
+        **kwargs,
+    ) -> BaseBatch:
+        if batch_size:
+            kwargs.update(batch_size=batch_size)
+        if node:
+            kwargs.update(node=node)
+        return await self.table._meta.db.batch(self, **kwargs)
+
     @property
-    def default_querystrings(self) -> t.Sequence[QueryString]:
+    def default_querystrings(self) -> Sequence[QueryString]:
         return [self.querystring]
diff --git a/piccolo/query/methods/refresh.py b/piccolo/query/methods/refresh.py
new file mode 100644
index 000000000..792e14df3
--- /dev/null
+++ b/piccolo/query/methods/refresh.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
+
+from piccolo.utils.encoding import JSONDict
+from piccolo.utils.sync import run_sync
+
+if TYPE_CHECKING:  # pragma: no cover
+    from piccolo.columns import Column
+    from piccolo.table import Table
+
+
+class Refresh:
+    """
+    Used to refresh :class:`Table <piccolo.table.Table>` instances with the
+    latest data from the database. Accessible via
+    :meth:`refresh <piccolo.table.Table.refresh>`.
+
+    :param instance:
+        The instance to refresh.
+    :param columns:
+        Which columns to refresh - if not specified, then all columns are
+        refreshed.
+    :param load_json:
+        Whether to load ``JSON`` / ``JSONB`` columns as objects, instead of
+        just a string.
+
+    """
+
+    def __init__(
+        self,
+        instance: Table,
+        columns: Optional[Sequence[Column]] = None,
+        load_json: bool = False,
+    ):
+        self.instance = instance
+
+        if columns:
+            for column in columns:
+                if len(column._meta.call_chain) > 0:
+                    raise ValueError(
+                        "We can't currently selectively refresh certain "
+                        "columns on child objects (e.g. Concert.band_1.name). "
+                        "Please just specify top level columns (e.g. "
+                        "Concert.band_1), and the entire child object will be "
+                        "refreshed."
+                    )
+
+        self.columns = columns
+        self.load_json = load_json
+
+    @property
+    def _columns(self) -> Sequence[Column]:
+        """
+        Works out which columns the user wants to refresh.
+        """
+        if self.columns:
+            return self.columns
+
+        return [
+            i for i in self.instance._meta.columns if not i._meta.primary_key
+        ]
+
+    def _get_columns(self, instance: Table, columns: Sequence[Column]):
+        """
+        If `prefetch` was used on the object, for example::
+
+            >>> await Band.objects(Band.manager)
+
+        We should also update the prefetched object.
+
+        It works multiple levels deep. If we refresh this::
+
+            >>> await Album.objects(Album.band.manager).first()
+
+        It will update the nested `band` object, and also the `manager`
+        object.
+
+        """
+        from piccolo.columns.column_types import ForeignKey
+        from piccolo.table import Table
+
+        select_columns = []
+
+        for column in columns:
+            if isinstance(column, ForeignKey) and isinstance(
+                (child_instance := getattr(instance, column._meta.name)),
+                Table,
+            ):
+                select_columns.extend(
+                    self._get_columns(
+                        child_instance,
+                        # Fetch all columns (even the primary key, just in
+                        # case the foreign key now references a different row).
+                        column.all_columns(),
+                    )
+                )
+            else:
+                select_columns.append(column)
+
+        return select_columns
+
+    def _update_instance(self, instance: Table, data_dict: dict):
+        """
+        Update the table instance. It is called recursively if the instance
+        has child instances.
+        """
+        for key, value in data_dict.items():
+            if isinstance(value, dict) and not isinstance(value, JSONDict):
+                # If the value is a dict, then it's a child instance.
+                if all(i is None for i in value.values()):
+                    # If all values in the nested object are None, then we can
+                    # safely assume that the object itself is null, as the
+                    # primary key value must be null.
+                    setattr(instance, key, None)
+                else:
+                    self._update_instance(getattr(instance, key), value)
+            else:
+                setattr(instance, key, value)
+
+    async def run(
+        self, in_pool: bool = True, node: Optional[str] = None
+    ) -> Table:
+        """
+        Run it asynchronously. For example::
+
+            await my_instance.refresh().run()
+
+            # or for convenience:
+            await my_instance.refresh()
+
+        Modifies the instance in place, but also returns it as a convenience.
+
+        """
+        instance = self.instance
+
+        if not instance._exists_in_db:
+            raise ValueError("The instance doesn't exist in the database.")
+
+        pk_column = instance._meta.primary_key
+
+        primary_key_value = getattr(instance, pk_column._meta.name, None)
+
+        if primary_key_value is None:
+            raise ValueError("The instance's primary key value isn't defined.")
+
+        columns = self._columns
+        if not columns:
+            raise ValueError("No columns to fetch.")
+
+        select_columns = self._get_columns(
+            instance=self.instance, columns=columns
+        )
+
+        data_dict = (
+            await instance.__class__.select(*select_columns)
+            .where(pk_column == primary_key_value)
+            .output(nested=True, load_json=self.load_json)
+            .first()
+            .run(node=node, in_pool=in_pool)
+        )
+
+        if data_dict is None:
+            raise ValueError(
+                "The object doesn't exist in the database any more."
+            )
+
+        self._update_instance(instance=instance, data_dict=data_dict)
+
+        return instance
+
+    def __await__(self):
+        """
+        If the user doesn't explicitly call :meth:`run`, proxy to it as a
+        convenience.
+        """
+        return self.run().__await__()
+
+    def run_sync(self, *args, **kwargs) -> Table:
+        """
+        Run it synchronously.
For example:: + + my_instance.refresh().run_sync() + + """ + return run_sync(self.run(*args, **kwargs)) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 98fe620ce..4ba3a2977 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -1,146 +1,157 @@ from __future__ import annotations -import decimal -import typing as t +import itertools from collections import OrderedDict +from collections.abc import Callable, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, + TypeVar, + Union, + overload, +) from piccolo.columns import Column, Selectable +from piccolo.columns.column_types import JSON, JSONB +from piccolo.columns.m2m import M2MSelect from piccolo.columns.readable import Readable -from piccolo.engine.base import Batch +from piccolo.custom_types import TableInstance +from piccolo.engine.base import BaseBatch from piccolo.query.base import Query from piccolo.query.mixins import ( + AsOfDelegate, + CallbackDelegate, + CallbackType, ColumnsDelegate, DistinctDelegate, GroupByDelegate, LimitDelegate, + LockRowsDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, + OrderByRaw, OutputDelegate, WhereDelegate, ) +from piccolo.query.proxy import Proxy from piccolo.querystring import QueryString +from piccolo.utils.dictionary import make_nested +from piccolo.utils.encoding import dump_json, load_json +from piccolo.utils.warnings import colored_warning -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.custom_types import Combinable from piccolo.table import Table # noqa +# Here to avoid breaking changes - will be removed in the future. +from piccolo.query.functions.aggregate import ( # noqa: F401 + Avg, + Count, + Max, + Min, + Sum, +) -def is_numeric_column(column: Column) -> bool: - return column.value_type in (int, decimal.Decimal, float) +class SelectRaw(Selectable): + def __init__(self, sql: str, *args: Any) -> None: + """ + Execute raw SQL in your select query. -class Avg(Selectable): - """ - AVG() SQL function. Column type must be numeric to run the query. + .. code-block:: python - await Band.select(Avg(Band.popularity)).run() or with aliases - await Band.select(Avg(Band.popularity, alias="popularity_avg")).run() - await Band.select(Avg(Band.popularity).as_alias("popularity_avg")).run() - """ + >>> await Band.select( + ... Band.name, + ... SelectRaw("log(popularity) AS log_popularity") + ... ) + [{'name': 'Pythonistas', 'log_popularity': 3.0}] - def __init__(self, column: Column, alias: str = "avg"): - if is_numeric_column(column): - self.column = column - else: - raise ValueError("Column type must be numeric to run the query.") - self.alias = alias + """ + self.querystring = QueryString(sql, *args) - def get_select_string(self, engine_type: str, just_alias=False) -> str: - column_name = self.column._meta.get_full_name(just_alias=just_alias) - return f"AVG({column_name}) AS {self.alias}" + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: + return self.querystring -class Count(Selectable): - """ - Used in conjunction with the ``group_by`` clause in ``Select`` queries. +OptionalDict = Optional[dict[str, Any]] - If a column is specified, the count is for non-null values in that - column. If no column is specified, the count is for all rows, whether - they have null values or not. 
- Band.select(Band.name, Count()).group_by(Band.name).run() - Band.select(Band.name, Count(alias="total")).group_by(Band.name).run() - Band.select(Band.name, Count().as_alias("total")).group_by(Band.name).run() +class First(Proxy["Select", OptionalDict]): """ - - def __init__( - self, column: t.Optional[Column] = None, alias: str = "count" - ): - self.column = column - self.alias = alias - - def get_select_string(self, engine_type: str, just_alias=False) -> str: - if self.column is None: - column_name = "*" - else: - column_name = self.column._meta.get_full_name( - just_alias=just_alias - ) - return f"COUNT({column_name}) AS {self.alias}" - - -class Max(Selectable): + This is for static typing purposes. """ - MAX() SQL function. - await Band.select(Max(Band.popularity)).run() or with aliases - await Band.select(Max(Band.popularity, alias="popularity_max")).run() - await Band.select(Max(Band.popularity).as_alias("popularity_max")).run() - """ + def __init__(self, query: Select): + self.query = query - def __init__(self, column: Column, alias: str = "max"): - self.column = column - self.alias = alias + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + ) -> OptionalDict: + rows = await self.query.run( + node=node, in_pool=in_pool, use_callbacks=False + ) + results = rows[0] if rows else None - def get_select_string(self, engine_type: str, just_alias=False) -> str: - column_name = self.column._meta.get_full_name(just_alias=just_alias) - return f"MAX({column_name}) AS {self.alias}" + modified_response = await self.query.callback_delegate.invoke( + results=results, kind=CallbackType.success + ) + return modified_response -class Min(Selectable): +class SelectList(Proxy["Select", list]): """ - MIN() SQL function. - - await Band.select(Min(Band.popularity)).run() - await Band.select(Min(Band.popularity, alias="popularity_min")).run() - await Band.select(Min(Band.popularity).as_alias("popularity_min")).run() + This is for static typing purposes. """ - def __init__(self, column: Column, alias: str = "min"): - self.column = column - self.alias = alias - - def get_select_string(self, engine_type: str, just_alias=False) -> str: - column_name = self.column._meta.get_full_name(just_alias=just_alias) - return f"MIN({column_name}) AS {self.alias}" + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + ) -> list: + rows = await self.query.run( + node=node, in_pool=in_pool, use_callbacks=False + ) + if len(rows) == 0: + response = [] + else: + if len(rows[0].keys()) != 1: + raise ValueError("Each row returned more than one value") -class Sum(Selectable): - """ - SUM() SQL function. Column type must be numeric to run the query. 
+ response = list(itertools.chain(*[j.values() for j in rows])) - await Band.select(Sum(Band.popularity)).run() - await Band.select(Sum(Band.popularity, alias="popularity_sum")).run() - await Band.select(Sum(Band.popularity).as_alias("popularity_sum")).run() - """ + modified_response = await self.query.callback_delegate.invoke( + results=response, kind=CallbackType.success + ) + return modified_response - def __init__(self, column: Column, alias: str = "sum"): - if is_numeric_column(column): - self.column = column - else: - raise ValueError("Column type must be numeric to run the query.") - self.alias = alias - def get_select_string(self, engine_type: str, just_alias=False) -> str: - column_name = self.column._meta.get_full_name(just_alias=just_alias) - return f"SUM({column_name}) AS {self.alias}" +class SelectJSON(Proxy["Select", str]): + """ + This is for static typing purposes. + """ + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + ) -> str: + rows = await self.query.run(node=node, in_pool=in_pool) + return dump_json(rows) -class Select(Query): +class Select(Query[TableInstance, list[dict[str, Any]]]): __slots__ = ( "columns_list", "exclude_secrets", + "as_of_delegate", "columns_delegate", "distinct_delegate", "group_by_delegate", @@ -148,19 +159,24 @@ class Select(Query): "offset_delegate", "order_by_delegate", "output_delegate", + "callback_delegate", "where_delegate", + "lock_rows_delegate", ) def __init__( self, - table: t.Type[Table], - columns_list: t.Sequence[t.Union[Selectable, str]] = [], + table: type[TableInstance], + columns_list: Optional[Sequence[Union[Selectable, str]]] = None, exclude_secrets: bool = False, **kwargs, ): + if columns_list is None: + columns_list = [] super().__init__(table, **kwargs) self.exclude_secrets = exclude_secrets + self.as_of_delegate = AsOfDelegate() self.columns_delegate = ColumnsDelegate() self.distinct_delegate = DistinctDelegate() self.group_by_delegate = GroupByDelegate() @@ -168,21 +184,26 @@ def __init__( self.offset_delegate = OffsetDelegate() self.order_by_delegate = OrderByDelegate() self.output_delegate = OutputDelegate() + self.callback_delegate = CallbackDelegate() self.where_delegate = WhereDelegate() + self.lock_rows_delegate = LockRowsDelegate() self.columns(*columns_list) - def columns(self, *columns: t.Union[Selectable, str]) -> Select: + def columns(self: Self, *columns: Union[Selectable, str]) -> Self: _columns = self.table._process_column_args(*columns) self.columns_delegate.columns(*_columns) return self - def distinct(self) -> Select: - self.distinct_delegate.distinct() + def distinct(self: Self, *, on: Optional[Sequence[Column]] = None) -> Self: + if on is not None and self.engine_type == "sqlite": + raise NotImplementedError("SQLite doesn't support DISTINCT ON") + + self.distinct_delegate.distinct(enabled=True, on=on) return self - def group_by(self, *columns: Column) -> Select: - _columns: t.List[Column] = [ + def group_by(self: Self, *columns: Union[Column, str]) -> Self: + _columns: list[Column] = [ i for i in self.table._process_column_args(*columns) if isinstance(i, Column) @@ -190,68 +211,294 @@ def group_by(self, *columns: Column) -> Select: self.group_by_delegate.group_by(*_columns) return self - def limit(self, number: int) -> Select: - self.limit_delegate.limit(number) + def as_of(self: Self, interval: str = "-1s") -> Self: + if self.engine_type != "cockroach": + raise NotImplementedError("Only CockroachDB supports AS OF") + + self.as_of_delegate.as_of(interval) return self - def 
first(self) -> Select: - self.limit_delegate.first() + def limit(self: Self, number: int) -> Self: + self.limit_delegate.limit(number) return self - def offset(self, number: int) -> Select: + def first(self) -> First: + self.limit_delegate.limit(1) + return First(query=self) + + def offset(self: Self, number: int) -> Self: self.offset_delegate.offset(number) return self - async def response_handler(self, response): - if self.limit_delegate._first: - if len(response) == 0: - return None - else: - return response[0] + def lock_rows( + self: Self, + lock_strength: Union[ + LockStrength, + Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: tuple[type[Table], ...] = (), + ) -> Self: + self.lock_rows_delegate.lock_rows( + lock_strength, nowait, skip_locked, of + ) + return self + + async def _splice_m2m_rows( + self, + response: list[dict[str, Any]], + secondary_table: type[Table], + secondary_table_pk: Column, + m2m_name: str, + m2m_select: M2MSelect, + as_list: bool = False, + ): + row_ids = list( + set(itertools.chain(*[row[m2m_name] for row in response])) + ) + extra_rows = ( + ( + await secondary_table.select( + *m2m_select.columns, + secondary_table_pk.as_alias("mapping_key"), + ) + .where(secondary_table_pk.is_in(row_ids)) + .output(load_json=m2m_select.load_json) + .run() + ) + if row_ids + else [] + ) + if as_list: + column_name = m2m_select.columns[0]._meta.name + extra_rows_map = { + row["mapping_key"]: row[column_name] for row in extra_rows + } else: - return response + extra_rows_map = { + row["mapping_key"]: { + key: value + for key, value in row.items() + if key != "mapping_key" + } + for row in extra_rows + } + for row in response: + row[m2m_name] = [extra_rows_map.get(i) for i in row[m2m_name]] + return response - def order_by(self, *columns: Column, ascending=True) -> Select: - _columns: t.List[Column] = [ + async def response_handler(self, response): + m2m_selects = [ i - for i in self.table._process_column_args(*columns) - if isinstance(i, Column) + for i in self.columns_delegate.selected_columns + if isinstance(i, M2MSelect) ] + for m2m_select in m2m_selects: + m2m_name = m2m_select.m2m._meta.name + secondary_table = m2m_select.m2m._meta.secondary_table + secondary_table_pk = secondary_table._meta.primary_key + + if self.engine_type == "sqlite": + # With M2M queries in SQLite, we always get the value back as a + # list of strings, so we need to do some type conversion. + value_type = ( + m2m_select.columns[0].__class__.value_type + if m2m_select.as_list and m2m_select.serialisation_safe + else secondary_table_pk.value_type + ) + try: + for row in response: + data = row[m2m_name] + row[m2m_name] = ( + [value_type(i) for i in row[m2m_name]] + if data + else [] + ) + except ValueError: + colored_warning( + "Unable to do type conversion for the " + f"{m2m_name} relation" + ) + + # If the user requested a single column, we just return that + # from the database. Otherwise we request the primary key + # value, so we can fetch the rest of the data in a subsequent + # SQL query - see below. 
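+                #
+                # A rough sketch of the two shapes (assuming a Band table
+                # with a `genres` M2M relation):
+                #
+                #     as_list + serialisation_safe, returned as-is:
+                #         {'name': 'Pythonistas', 'genres': ['Rock']}
+                #     otherwise, primary keys which get spliced below:
+                #         {'name': 'Pythonistas', 'genres': [1, 2]}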
+                if m2m_select.as_list:
+                    if m2m_select.serialisation_safe:
+                        pass
+                    else:
+                        response = await self._splice_m2m_rows(
+                            response,
+                            secondary_table,
+                            secondary_table_pk,
+                            m2m_name,
+                            m2m_select,
+                            as_list=True,
+                        )
+                else:
+                    if (
+                        len(m2m_select.columns) == 1
+                        and m2m_select.serialisation_safe
+                    ):
+                        column_name = m2m_select.columns[0]._meta.name
+                        for row in response:
+                            row[m2m_name] = [
+                                {column_name: i} for i in row[m2m_name]
+                            ]
+                    else:
+                        response = await self._splice_m2m_rows(
+                            response,
+                            secondary_table,
+                            secondary_table_pk,
+                            m2m_name,
+                            m2m_select,
+                        )
+
+            elif self.engine_type in ("postgres", "cockroach"):
+                if m2m_select.as_list:
+                    # We get the data back as an array, and can just return it
+                    # unless it's JSON.
+                    if (
+                        type(m2m_select.columns[0]) in (JSON, JSONB)
+                        and m2m_select.load_json
+                    ):
+                        for row in response:
+                            data = row[m2m_name]
+                            row[m2m_name] = [load_json(i) for i in data]
+                elif m2m_select.serialisation_safe:
+                    # If the columns requested can be safely serialised, they
+                    # are returned as a JSON string, so we need to deserialise
+                    # it.
+                    for row in response:
+                        data = row[m2m_name]
+                        row[m2m_name] = load_json(data) if data else []
+                else:
+                    # If the data can't be safely serialised as JSON, we get
+                    # back an array of primary key values, and need to
+                    # splice in the correct values using Python.
+                    response = await self._splice_m2m_rows(
+                        response,
+                        secondary_table,
+                        secondary_table_pk,
+                        m2m_name,
+                        m2m_select,
+                    )
+
+        #######################################################################
+
+        # If no columns were specified, it's a select *, so we know that
+        # no columns were selected from related tables.
+        was_select_star = len(self.columns_delegate.selected_columns) == 0
+
+        if self.output_delegate._output.nested and not was_select_star:
+            return [make_nested(i) for i in response]
+        else:
+            return response
+
+    def order_by(
+        self: Self, *columns: Union[Column, str, OrderByRaw], ascending=True
+    ) -> Self:
+        """
+        :param columns:
+            Either a :class:`piccolo.columns.base.Column` instance, a string
+            representing a column name, or :class:`piccolo.query.OrderByRaw`,
+            which allows you to use raw SQL for complex use cases like
+            ``OrderByRaw('random()')``.
+        """
+        _columns: list[Union[Column, OrderByRaw]] = []
+        for column in columns:
+            if isinstance(column, str):
+                _columns.append(self.table._meta.get_column_by_name(column))
+            else:
+                _columns.append(column)
+
         self.order_by_delegate.order_by(*_columns, ascending=ascending)
         return self

+    @overload
+    def output(self: Self, *, as_list: bool) -> SelectList:  # type: ignore
+        ...
+
+    @overload
+    def output(self: Self, *, as_json: bool) -> SelectJSON:  # type: ignore
+        ...
+
+    @overload
+    def output(self: Self, *, load_json: bool) -> Self: ...
+
+    @overload
+    def output(self: Self, *, load_json: bool, as_list: bool) -> SelectJSON:  # type: ignore # noqa: E501
+        ...
+
+    @overload
+    def output(self: Self, *, load_json: bool, nested: bool) -> Self: ...
+
+    @overload
+    def output(self: Self, *, nested: bool) -> Self: ...
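+    # The overloads above exist purely for static typing: they let type
+    # checkers narrow the return type based on the output mode chosen.
+    # A sketch (assuming a ``Band`` table):
+    #
+    #     await Band.select(Band.name).output(as_list=True)  # -> SelectList
+    #     await Band.select(Band.name).output(as_json=True)  # -> SelectJSON
+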
+ def output( - self, + self: Self, + *, as_list: bool = False, as_json: bool = False, load_json: bool = False, - ) -> Select: + nested: bool = False, + ) -> Union[Self, SelectJSON, SelectList]: self.output_delegate.output( - as_list=as_list, as_json=as_json, load_json=load_json + as_list=as_list, + as_json=as_json, + load_json=load_json, + nested=nested, ) + if as_list: + return SelectList(query=self) + elif as_json: + return SelectJSON(query=self) + return self - def where(self, where: Combinable) -> Select: - self.where_delegate.where(where) + def callback( + self: Self, + callbacks: Union[Callable, list[Callable]], + *, + on: CallbackType = CallbackType.success, + ) -> Self: + self.callback_delegate.callback(callbacks, on=on) + return self + + def where(self: Self, *where: Union[Combinable, QueryString]) -> Self: + self.where_delegate.where(*where) return self async def batch( - self, batch_size: t.Optional[int] = None, **kwargs - ) -> Batch: + self, + batch_size: Optional[int] = None, + node: Optional[str] = None, + **kwargs, + ) -> BaseBatch: if batch_size: kwargs.update(batch_size=batch_size) + if node: + kwargs.update(node=node) return await self.table._meta.db.batch(self, **kwargs) ########################################################################### - def _get_joins(self, columns: t.Sequence[Selectable]) -> t.List[str]: + def _get_joins(self, columns: Sequence[Selectable]) -> list[str]: """ A call chain is a sequence of foreign keys representing joins which need to be made to retrieve a column in another table. """ - joins: t.List[str] = [] + joins: list[str] = [] - readables: t.List[Readable] = [ + readables: list[Readable] = [ i for i in columns if isinstance(i, Readable) ] @@ -259,35 +506,42 @@ def _get_joins(self, columns: t.Sequence[Selectable]) -> t.List[str]: for readable in readables: columns += readable.columns + querystrings: list[QueryString] = [ + i for i in columns if isinstance(i, QueryString) + ] + for querystring in querystrings: + if querystring_columns := getattr(querystring, "columns", []): + columns += querystring_columns + for column in columns: if not isinstance(column, Column): continue - _joins: t.List[str] = [] + _joins: list[str] = [] for index, key in enumerate(column._meta.call_chain, 0): - table_alias = "$".join( - [ - f"{_key._meta.table._meta.tablename}${_key._meta.name}" - for _key in column._meta.call_chain[: index + 1] - ] - ) - key._meta.table_alias = table_alias + table_alias = key.table_alias if index > 0: left_tablename = column._meta.call_chain[ index - 1 - ]._meta.table_alias + ].table_alias else: - left_tablename = key._meta.table._meta.tablename + left_tablename = ( + key._meta.table._meta.get_formatted_tablename() + ) # noqa: E501 right_tablename = ( - key._foreign_key_meta.resolved_references._meta.tablename + key._foreign_key_meta.resolved_references._meta.get_formatted_tablename() # noqa: E501 ) + pk_name = column._meta.call_chain[ + index + ]._foreign_key_meta.resolved_target_column._meta.name + _joins.append( - f"LEFT JOIN {right_tablename} {table_alias}" + f'LEFT JOIN {right_tablename} "{table_alias}"' " ON " - f"({left_tablename}.{key._meta.name} = {table_alias}.id)" + f'({left_tablename}."{key._meta.db_column_name}" = "{table_alias}"."{pk_name}")' # noqa: E501 ) joins.extend(_joins) @@ -295,23 +549,21 @@ def _get_joins(self, columns: t.Sequence[Selectable]) -> t.List[str]: # Remove duplicates return list(OrderedDict.fromkeys(joins)) - def _check_valid_call_chain(self, keys: t.Sequence[Selectable]) -> bool: + def 
_check_valid_call_chain(self, keys: Sequence[Selectable]) -> bool: for column in keys: if not isinstance(column, Column): continue - if column._meta.call_chain: + if column._meta.call_chain and len(column._meta.call_chain) > 10: # Make sure the call_chain isn't too large to discourage # very inefficient queries. - - if len(column._meta.call_chain) > 10: - raise Exception( - "Joining more than 10 tables isn't supported - " - "please restructure your query." - ) + raise Exception( + "Joining more than 10 tables isn't supported - " + "please restructure your query." + ) return True @property - def default_querystrings(self) -> t.Sequence[QueryString]: + def default_querystrings(self) -> Sequence[QueryString]: # JOIN self._check_valid_call_chain(self.columns_delegate.selected_columns) @@ -322,7 +574,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: ) # Combine all joins, and remove duplicates - joins: t.List[str] = list( + joins: list[str] = list( OrderedDict.fromkeys(select_joins + where_joins + order_by_joins) ) @@ -339,36 +591,44 @@ def default_querystrings(self) -> t.Sequence[QueryString]: engine_type = self.table._meta.db.engine_type - select_strings: t.List[str] = [ + select_strings: list[QueryString] = [ c.get_select_string(engine_type=engine_type) for c in self.columns_delegate.selected_columns ] - columns_str = ", ".join(select_strings) ####################################################################### - select = ( - "SELECT DISTINCT" if self.distinct_delegate._distinct else "SELECT" - ) - query = f"{select} {columns_str} FROM {self.table._meta.tablename}" + args: list[Any] = [] + + query = "SELECT" + + distinct = self.distinct_delegate._distinct + if distinct.on: + distinct.validate_on(self.order_by_delegate._order_by) + query += "{}" + args.append(distinct.querystring) + + columns_str = ", ".join("{}" for _ in select_strings) + query += f" {columns_str} FROM {self.table._meta.get_formatted_tablename()}" # noqa: E501 + args.extend(select_strings) for join in joins: query += f" {join}" - ####################################################################### - - args: t.List[t.Any] = [] + if self.as_of_delegate._as_of: + query += "{}" + args.append(self.as_of_delegate._as_of.querystring) if self.where_delegate._where: query += " WHERE {}" args.append(self.where_delegate._where.querystring) if self.group_by_delegate._group_by: - query += " {}" + query += "{}" args.append(self.group_by_delegate._group_by.querystring) - if self.order_by_delegate._order_by: - query += " {}" + if self.order_by_delegate._order_by.order_by_items: + query += "{}" args.append(self.order_by_delegate._order_by.querystring) if ( @@ -382,13 +642,41 @@ def default_querystrings(self) -> t.Sequence[QueryString]: ) if self.limit_delegate._limit: - query += " {}" + query += "{}" args.append(self.limit_delegate._limit.querystring) if self.offset_delegate._offset: - query += " {}" + query += "{}" args.append(self.offset_delegate._offset.querystring) + if self.lock_rows_delegate._lock_rows: + if engine_type == "sqlite": + raise NotImplementedError( + "SQLite doesn't support row locking e.g. SELECT ... 
FOR " + "UPDATE" + ) + + query += "{}" + args.append(self.lock_rows_delegate._lock_rows.querystring) + querystring = QueryString(query, *args) return [querystring] + + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + use_callbacks: bool = True, + **kwargs, + ) -> list[dict[str, Any]]: + results = await super().run(node=node, in_pool=in_pool) + if use_callbacks: + return await self.callback_delegate.invoke( + results, kind=CallbackType.success + ) + else: + return results + + +Self = TypeVar("Self", bound=Select) diff --git a/piccolo/query/methods/table_exists.py b/piccolo/query/methods/table_exists.py index 72a399ff5..2d90059cd 100644 --- a/piccolo/query/methods/table_exists.py +++ b/piccolo/query/methods/table_exists.py @@ -1,32 +1,45 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence +from piccolo.custom_types import TableInstance from piccolo.query.base import Query from piccolo.querystring import QueryString -class TableExists(Query): +class TableExists(Query[TableInstance, bool]): - __slots__: t.Tuple = tuple() + __slots__: tuple = () async def response_handler(self, response): return bool(response[0]["exists"]) @property - def sqlite_querystrings(self) -> t.Sequence[QueryString]: + def sqlite_querystrings(self) -> Sequence[QueryString]: return [ QueryString( "SELECT EXISTS(SELECT * FROM sqlite_master WHERE " - f"name = '{self.table._meta.tablename}') AS 'exists'" + "name = {}) AS 'exists'", + self.table._meta.tablename, ) ] @property - def postgres_querystrings(self) -> t.Sequence[QueryString]: - return [ - QueryString( - "SELECT EXISTS(SELECT * FROM information_schema.tables WHERE " - f"table_name = '{self.table._meta.tablename}')" + def postgres_querystrings(self) -> Sequence[QueryString]: + subquery = QueryString( + "SELECT * FROM information_schema.tables WHERE table_name = {}", + self.table._meta.tablename, + ) + + if self.table._meta.schema: + subquery = QueryString( + "{} AND table_schema = {}", subquery, self.table._meta.schema ) - ] + + query = QueryString("SELECT EXISTS({})", subquery) + + return [query] + + @property + def cockroach_querystrings(self) -> Sequence[QueryString]: + return self.postgres_querystrings diff --git a/piccolo/query/methods/update.py b/piccolo/query/methods/update.py index 013c9d53e..97715b7d9 100644 --- a/piccolo/query/methods/update.py +++ b/piccolo/query/methods/update.py @@ -1,74 +1,118 @@ from __future__ import annotations -import typing as t -from dataclasses import dataclass +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional, Union -from piccolo.custom_types import Combinable +from piccolo.custom_types import Combinable, TableInstance from piccolo.query.base import Query -from piccolo.query.mixins import ValuesDelegate, WhereDelegate +from piccolo.query.mixins import ( + ReturningDelegate, + ValuesDelegate, + WhereDelegate, +) from piccolo.querystring import QueryString -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column - from piccolo.table import Table -@dataclass -class Update(Query): +class UpdateError(Exception): + pass - __slots__ = ("values_delegate", "where_delegate") - def __init__(self, table: t.Type[Table], **kwargs): +class Update(Query[TableInstance, list[Any]]): + __slots__ = ( + "force", + "returning_delegate", + "values_delegate", + "where_delegate", + ) + + def __init__( + self, table: type[TableInstance], force: bool = False, **kwargs + ): 
         super().__init__(table, **kwargs)
+        self.force = force
+        self.returning_delegate = ReturningDelegate()
         self.values_delegate = ValuesDelegate(table=table)
         self.where_delegate = WhereDelegate()
 
+    ###########################################################################
+    # Clauses
+
     def values(
-        self, values: t.Dict[t.Union[Column, str], t.Any] = {}, **kwargs
+        self,
+        values: Optional[dict[Union[Column, str], Any]] = None,
+        **kwargs,
     ) -> Update:
+        if values is None:
+            values = {}
         values = dict(values, **kwargs)
         self.values_delegate.values(values)
         return self
 
-    def where(self, where: Combinable) -> Update:
-        self.where_delegate.where(where)
+    def where(self, *where: Union[Combinable, QueryString]) -> Update:
+        self.where_delegate.where(*where)
+        return self
+
+    def returning(self, *columns: Column) -> Update:
+        self.returning_delegate.returning(columns)
         return self
 
-    def validate(self):
+    ###########################################################################
+
+    def _validate(self):
+        """
+        Called at the start of :meth:`piccolo.query.base.Query.run` to make
+        sure the user has configured the query correctly before running it.
+        """
         if len(self.values_delegate._values) == 0:
-            raise ValueError(
-                "No values were specified to update - please use .values"
-            )
+            raise ValueError("No values were specified to update.")
 
         for column, _ in self.values_delegate._values.items():
             if len(column._meta.call_chain) > 0:
                 raise ValueError(
-                    "Related values can't be updated via an update"
+                    "Related values can't be updated via an update."
                 )
 
-    @property
-    def default_querystrings(self) -> t.Sequence[QueryString]:
-        self.validate()
+        if (not self.where_delegate._where) and (not self.force):
+            classname = self.table.__name__
+            raise UpdateError(
+                "Do you really want to update all rows in "
+                f"{classname}? If so, pass `force=True` into "
+                f"`{classname}.update`. Otherwise, add a where clause."
+            )
 
+    ###########################################################################
+
+    @property
+    def default_querystrings(self) -> Sequence[QueryString]:
         columns_str = ", ".join(
-            [
-                f"{col._meta.name} = {{}}"
-                for col, _ in self.values_delegate._values.items()
-            ]
+            f'"{col._meta.db_column_name}" = {{}}'
+            for col, _ in self.values_delegate._values.items()
         )
 
-        query = f"UPDATE {self.table._meta.tablename} SET " + columns_str
+        query = f"UPDATE {self.table._meta.get_formatted_tablename()} SET {columns_str}"  # noqa: E501
 
         querystring = QueryString(
             query, *self.values_delegate.get_sql_values()
         )
 
         if self.where_delegate._where:
-            where_querystring = QueryString(
+            # The JOIN syntax isn't allowed in SQL UPDATE queries, so we need
+            # to write the WHERE clause differently, using a sub select.
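To make the comment above concrete: SQL doesn't allow ``JOIN`` syntax inside an ``UPDATE`` statement, so a where clause which traverses a foreign key is rewritten via ``querystring_for_update_and_delete`` - typically as a sub select on the primary key. A rough sketch with a hypothetical ``Band`` table (the exact generated SQL is up to that property):

```python
# Hypothetical `Band` table with a `manager` ForeignKey - illustrative only.
await Band.update({Band.popularity: 100}).where(
    Band.manager.name == "Guido"
)
# Conceptually compiled as something like:
#   UPDATE band SET popularity = $1
#   WHERE band.id IN (
#       SELECT band.id FROM band JOIN manager ... WHERE manager.name = $2
#   )
```

Separately, the ``_validate`` hook above means that ``Band.update({Band.popularity: 100})`` with neither a where clause nor ``force=True`` now raises ``UpdateError``, rather than silently updating every row.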
+ + querystring = QueryString( "{} WHERE {}", querystring, - self.where_delegate._where.querystring, + self.where_delegate._where.querystring_for_update_and_delete, ) - return [where_querystring] - else: - return [querystring] + + if self.returning_delegate._returning: + querystring = QueryString( + "{}{}", + querystring, + self.returning_delegate._returning.querystring, + ) + + return [querystring] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 838f52ff6..178d793bf 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -1,18 +1,90 @@ from __future__ import annotations -import typing as t +import asyncio +import collections.abc +import itertools +from collections.abc import Callable, Sequence from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Literal, Optional, Union -from piccolo.columns import And, Column, Or, Secret, Where +from piccolo.columns import And, Column, Or, Where +from piccolo.columns.column_types import ForeignKey +from piccolo.columns.combination import WhereRaw from piccolo.custom_types import Combinable from piccolo.querystring import QueryString +from piccolo.utils.list import flatten from piccolo.utils.sql_values import convert_to_sql_value -if t.TYPE_CHECKING: # pragma: no cover - from piccolo.columns.base import Selectable +if TYPE_CHECKING: # pragma: no cover + from piccolo.querystring import Selectable from piccolo.table import Table # noqa +class DistinctOnError(ValueError): + """ + Raised when ``DISTINCT ON`` queries are malformed. + """ + + pass + + +@dataclass +class Distinct: + __slots__ = ("enabled", "on") + + enabled: bool + on: Optional[Sequence[Column]] + + @property + def querystring(self) -> QueryString: + if self.enabled: + if self.on: + column_names = ", ".join( + i._meta.get_full_name(with_alias=False) for i in self.on + ) + return QueryString(f" DISTINCT ON ({column_names})") + else: + return QueryString(" DISTINCT") + else: + return QueryString(" ALL") + + def validate_on(self, order_by: OrderBy): + """ + When using the `on` argument, the first column must match the first + order by column. + + :raises DistinctOnError: + If the columns don't match. + + """ + validated = True + + try: + first_order_column = order_by.order_by_items[0].columns[0] + except IndexError: + validated = False + else: + if not self.on: + validated = False + elif isinstance(first_order_column, Column) and not self.on[ + 0 + ]._equals(first_order_column): + validated = False + + if not validated: + raise DistinctOnError( + "The first `order_by` column must match the first column " + "passed to `on`." + ) + + def __str__(self) -> str: + return self.querystring.__str__() + + def copy(self) -> Distinct: + return self.__class__(enabled=self.enabled, on=self.on) + + @dataclass class Limit: __slots__ = ("number",) @@ -20,7 +92,7 @@ class Limit: number: int def __post_init__(self): - if type(self.number) != int: + if not isinstance(self.number, int): raise TypeError("Limit must be an integer") @property @@ -34,6 +106,24 @@ def copy(self) -> Limit: return self.__class__(number=self.number) +@dataclass +class AsOf: + __slots__ = ("interval",) + + interval: str + + def __post_init__(self): + if not isinstance(self.interval, str): + raise TypeError("As Of must be a string. 
Example: '-1s'") + + @property + def querystring(self) -> QueryString: + return QueryString(f" AS OF SYSTEM TIME '{self.interval}'") + + def __str__(self) -> str: + return self.querystring.__str__() + + @dataclass class Offset: __slots__ = ("number",) @@ -41,8 +131,8 @@ class Offset: number: int def __post_init__(self): - if type(self.number) != int: - raise TypeError("Limit must be an integer") + if not isinstance(self.number, int): + raise TypeError("Offset must be an integer") @property def querystring(self) -> QueryString: @@ -53,31 +143,76 @@ def __str__(self) -> str: @dataclass -class OrderBy: +class OrderByRaw: + __slots__ = ("sql",) + + sql: str + + +@dataclass +class OrderByItem: __slots__ = ("columns", "ascending") - columns: t.Sequence[Column] + columns: Sequence[Union[Column, OrderByRaw]] ascending: bool + +@dataclass +class OrderBy: + order_by_items: list[OrderByItem] = field(default_factory=list) + @property def querystring(self) -> QueryString: - order = "ASC" if self.ascending else "DESC" - columns_names = ", ".join( - [i._meta.get_full_name(just_alias=True) for i in self.columns] - ) - return QueryString(f" ORDER BY {columns_names} {order}") + order_by_strings: list[str] = [] + for order_by_item in self.order_by_items: + order = "ASC" if order_by_item.ascending else "DESC" + for column in order_by_item.columns: + if isinstance(column, Column): + expression = column._meta.get_full_name(with_alias=False) + elif isinstance(column, OrderByRaw): + expression = column.sql + else: + raise ValueError("Unrecognised order_by") + + order_by_strings.append(f"{expression} {order}") + + return QueryString(f" ORDER BY {', '.join(order_by_strings)}") def __str__(self): return self.querystring.__str__() @dataclass -class Output: +class Returning: + __slots__ = ("columns",) + + columns: list[Column] + @property + def querystring(self) -> QueryString: + column_names = [] + for column in self.columns: + column_names.append( + f'"{column._meta.db_column_name}" AS "{column._alias}"' + if column._alias + else f'"{column._meta.db_column_name}"' + ) + + columns_string = ", ".join(column_names) + + return QueryString(f" RETURNING {columns_string}") + + def __str__(self): + return self.querystring.__str__() + + +@dataclass +class Output: as_json: bool = False as_list: bool = False as_objects: bool = False load_json: bool = False + nested: bool = False def copy(self) -> Output: return self.__class__( @@ -85,14 +220,24 @@ def copy(self) -> Output: as_list=self.as_list, as_objects=self.as_objects, load_json=self.load_json, + nested=self.nested, ) +class CallbackType(Enum): + success = auto() + + @dataclass -class WhereDelegate: +class Callback: + kind: CallbackType + target: Callable - _where: t.Optional[Combinable] = None - _where_columns: t.List[Column] = field(default_factory=list) + +@dataclass +class WhereDelegate: + _where: Optional[Combinable] = None + _where_columns: list[Column] = field(default_factory=list) def get_where_columns(self): """ @@ -100,65 +245,112 @@ def get_where_columns(self): needed. 
""" self._where_columns = [] - self._extract_columns(self._where) + if self._where is not None: + self._extract_columns(self._where) return self._where_columns def _extract_columns(self, combinable: Combinable): if isinstance(combinable, Where): self._where_columns.append(combinable.column) - elif isinstance(combinable, And) or isinstance(combinable, Or): + elif isinstance(combinable, (And, Or)): self._extract_columns(combinable.first) self._extract_columns(combinable.second) + elif isinstance(combinable, WhereRaw): + self._where_columns.extend(combinable.querystring.columns) - def where(self, where: Combinable): - if self._where: - self._where = And(self._where, where) - else: - self._where = where + def where(self, *where: Union[Combinable, QueryString]): + for arg in where: + if isinstance(arg, bool): + raise ValueError( + "A boolean value has been passed in to a where clause. " + "This is probably a mistake. For example " + "`.where(MyTable.some_column is None)` instead of " + "`.where(MyTable.some_column.is_null())`." + ) + + if isinstance(arg, QueryString): + # If a raw QueryString is passed in. + arg = WhereRaw(arg.template, *arg.args) + + self._where = And(self._where, arg) if self._where else arg @dataclass class OrderByDelegate: + _order_by: OrderBy = field(default_factory=OrderBy) - _order_by: t.Optional[OrderBy] = None + def get_order_by_columns(self) -> list[Column]: + """ + Used to work out which columns are needed for joins. + """ + return [ + i + for i in itertools.chain( + *[i.columns for i in self._order_by.order_by_items] + ) + if isinstance(i, Column) + ] - def get_order_by_columns(self) -> t.List[Column]: - return list(self._order_by.columns) if self._order_by else [] + def order_by(self, *columns: Union[Column, OrderByRaw], ascending=True): + if len(columns) < 1: + raise ValueError("At least one column must be passed to order_by.") - def order_by(self, *columns: Column, ascending=True): - self._order_by = OrderBy(columns, ascending) + self._order_by.order_by_items.append( + OrderByItem(columns=columns, ascending=ascending) + ) @dataclass class LimitDelegate: - - _limit: t.Optional[Limit] = None + _limit: Optional[Limit] = None _first: bool = False def limit(self, number: int): self._limit = Limit(number) - def first(self): - self.limit(1) - self._first = True - def copy(self) -> LimitDelegate: _limit = self._limit.copy() if self._limit is not None else None return self.__class__(_limit=_limit, _first=self._first) +@dataclass +class AsOfDelegate: + """ + Time travel queries using "As Of" syntax. + Currently supports Cockroach using AS OF SYSTEM TIME. + """ + + _as_of: Optional[AsOf] = None + + def as_of(self, interval: str = "-1s"): + self._as_of = AsOf(interval) + + @dataclass class DistinctDelegate: + _distinct: Distinct = field( + default_factory=lambda: Distinct(enabled=False, on=None) + ) - _distinct: bool = False + def distinct(self, enabled: bool, on: Optional[Sequence[Column]] = None): + if on and not isinstance(on, collections.abc.Sequence): + # Check a sequence is passed in, otherwise the user will get some + # unuseful errors later on. 
+ raise ValueError("`on` must be a sequence of `Column` instances") - def distinct(self): - self._distinct = True + self._distinct = Distinct(enabled=enabled, on=on) @dataclass -class CountDelegate: +class ReturningDelegate: + _returning: Optional[Returning] = None + + def returning(self, columns: Sequence[Column]): + self._returning = Returning(columns=list(columns)) + +@dataclass +class CountDelegate: _count: bool = False def count(self): @@ -167,10 +359,9 @@ def count(self): @dataclass class AddDelegate: + _add: list[Table] = field(default_factory=list) - _add: t.List[Table] = field(default_factory=list) - - def add(self, *instances: Table, table_class: t.Type[Table]): + def add(self, *instances: Table, table_class: type[Table]): for instance in instances: if not isinstance(instance, table_class): raise TypeError("Incompatible type added.") @@ -192,9 +383,10 @@ class OutputDelegate: def output( self, - as_list: t.Optional[bool] = None, - as_json: t.Optional[bool] = None, - load_json: t.Optional[bool] = None, + as_list: Optional[bool] = None, + as_json: Optional[bool] = None, + load_json: Optional[bool] = None, + nested: Optional[bool] = None, ): """ :param as_list: @@ -207,6 +399,8 @@ def output( If True, any JSON fields will have the JSON values returned from the database loaded as Python objects. """ + # We do it like this, so output can be called multiple times, without + # overriding any existing values if they're not specified. if as_list is not None: self._output.as_list = bool(as_list) @@ -216,9 +410,86 @@ def output( if load_json is not None: self._output.load_json = bool(load_json) + if nested is not None: + self._output.nested = bool(nested) + def copy(self) -> OutputDelegate: - _output = self._output.copy() if self._output is not None else None - return self.__class__(_output=_output) + return self.__class__(_output=self._output.copy()) + + +@dataclass +class CallbackDelegate: + """ + Example usage: + + .callback(my_handler_function) + .callback(print, on=CallbackType.success) + .callback(my_handler_coroutine) + .callback([handler1, handler2]) + """ + + _callbacks: dict[CallbackType, list[Callback]] = field( + default_factory=lambda: {kind: [] for kind in CallbackType} + ) + + def callback( + self, + callbacks: Union[Callable, list[Callable]], + *, + on: CallbackType, + ): + if isinstance(callbacks, list): + self._callbacks[on].extend( + Callback(kind=on, target=callback) for callback in callbacks + ) + else: + self._callbacks[on].append(Callback(kind=on, target=callbacks)) + + async def invoke(self, results: Any, *, kind: CallbackType) -> Any: + """ + Utility function that invokes the registered callbacks in the correct + way, handling both sync and async callbacks. Only callbacks of the + given kind are invoked. + Results are passed through the callbacks in the order they were added, + with each callback able to transform them. This function returns the + transformed results. 
+ """ + for callback in self._callbacks[kind]: + if asyncio.iscoroutinefunction(callback.target): + results = await callback.target(results) + else: + results = callback.target(results) + + return results + + +@dataclass +class PrefetchDelegate: + """ + Example usage: + + .prefetch(MyTable.column_a, MyTable.column_b) + """ + + fk_columns: list[ForeignKey] = field(default_factory=list) + + def prefetch(self, *fk_columns: Union[ForeignKey, list[ForeignKey]]): + """ + :param columns: + We accept ``ForeignKey`` and ``List[ForeignKey]`` here, in case + someone passes in a list by accident when using ``all_related()``, + in which case we flatten the list. + + """ + _fk_columns: list[ForeignKey] = [] + for column in fk_columns: + if isinstance(column, list): + _fk_columns.extend(column) + else: + _fk_columns.append(column) + + combined = self.fk_columns + _fk_columns + self.fk_columns = combined @dataclass @@ -229,17 +500,29 @@ class ColumnsDelegate: .columns(MyTable.column_a, MyTable.column_b) """ - selected_columns: t.Sequence[Selectable] = field(default_factory=list) + selected_columns: Sequence[Selectable] = field(default_factory=list) - def columns(self, *columns: Selectable): - combined = list(self.selected_columns) + list(columns) + def columns(self, *columns: Union[Selectable, list[Selectable]]): + """ + :param columns: + We accept ``Selectable`` and ``List[Selectable]`` here, in case + someone passes in a list by accident when using ``all_columns()``, + in which case we flatten the list. + + """ + _columns = flatten(columns) + combined = list(self.selected_columns) + _columns self.selected_columns = combined def remove_secret_columns(self): - self.selected_columns = [ - i for i in self.selected_columns if not isinstance(i, Secret) + non_secret = [ + i + for i in self.selected_columns + if not isinstance(i, Column) or not i._meta.secret ] + self.selected_columns = non_secret + @dataclass class ValuesDelegate: @@ -247,10 +530,10 @@ class ValuesDelegate: Used to specify new column values - primarily used in update queries. """ - table: t.Type[Table] - _values: t.Dict[Column, t.Any] = field(default_factory=dict) + table: type[Table] + _values: dict[Column, Any] = field(default_factory=dict) - def values(self, values: t.Dict[t.Union[Column, str], t.Any]): + def values(self, values: dict[Union[Column, str], Any]): """ Example usage: @@ -265,7 +548,7 @@ def values(self, values: t.Dict[t.Union[Column, str], t.Any]): .values(column_a=1}) """ - cleaned_values: t.Dict[Column, t.Any] = {} + cleaned_values: dict[Column, Any] = {} for key, value in values.items(): if isinstance(key, Column): column = key @@ -280,7 +563,7 @@ def values(self, values: t.Dict[t.Union[Column, str], t.Any]): self._values.update(cleaned_values) - def get_sql_values(self) -> t.List[t.Any]: + def get_sql_values(self) -> list[Any]: """ Convert any Enums into values, and serialise any JSON. """ @@ -297,12 +580,13 @@ class OffsetDelegate: Typically used in conjunction with order_by and limit. 
- Example usage: + Example usage:: + + .offset(100) - .offset(100) """ - _offset: t.Optional[Offset] = None + _offset: Optional[Offset] = None def offset(self, number: int = 0): self._offset = Offset(number) @@ -312,13 +596,14 @@ def offset(self, number: int = 0): class GroupBy: __slots__ = ("columns",) - columns: t.Sequence[Column] + columns: Sequence[Column] @property def querystring(self) -> QueryString: columns_names = ", ".join( - [i._meta.get_full_name(just_alias=True) for i in self.columns] + i._meta.get_full_name(with_alias=False) for i in self.columns ) + return QueryString(f" GROUP BY {columns_names}") def __str__(self): @@ -328,12 +613,262 @@ def __str__(self): @dataclass class GroupByDelegate: """ - Used to group results - needed when doing aggregation. + Used to group results - needed when doing aggregation:: + + .group_by(Band.name) - .group_by(Band.name) """ - _group_by: t.Optional[GroupBy] = None + _group_by: Optional[GroupBy] = None def group_by(self, *columns: Column): self._group_by = GroupBy(columns=columns) + + +class OnConflictAction(str, Enum): + """ + Specify which action to take on conflict. + """ + + do_nothing = "DO NOTHING" + do_update = "DO UPDATE" + + +@dataclass +class OnConflictItem: + target: Optional[Union[str, Column, tuple[Column, ...]]] = None + action: Optional[OnConflictAction] = None + values: Optional[Sequence[Union[Column, tuple[Column, Any]]]] = None + where: Optional[Combinable] = None + + @property + def target_string(self) -> str: + target = self.target + assert target + + def to_string(value) -> str: + if isinstance(value, Column): + return f'"{value._meta.db_column_name}"' + else: + raise ValueError("OnConflict.target isn't a valid type") + + if isinstance(target, str): + return f'ON CONSTRAINT "{target}"' + elif isinstance(target, Column): + return f"({to_string(target)})" + elif isinstance(target, tuple): + columns_str = ", ".join([to_string(i) for i in target]) + return f"({columns_str})" + else: + raise ValueError("OnConflict.target isn't a valid type") + + @property + def action_string(self) -> QueryString: + action = self.action + if isinstance(action, OnConflictAction): + if action == OnConflictAction.do_nothing: + return QueryString(OnConflictAction.do_nothing.value) + elif action == OnConflictAction.do_update: + values = [] + query = f"{OnConflictAction.do_update.value} SET" + + if not self.values: + raise ValueError("No values specified for `on conflict`") + + for value in self.values: + if isinstance(value, Column): + column_name = value._meta.db_column_name + query += f' "{column_name}"=EXCLUDED."{column_name}",' + elif isinstance(value, tuple): + column = value[0] + value_ = value[1] + if isinstance(column, Column): + column_name = column._meta.db_column_name + else: + raise ValueError("Unsupported column type") + + query += f' "{column_name}"={{}},' + values.append(value_) + + return QueryString(query.rstrip(","), *values) + + raise ValueError("OnConflict.action isn't a valid type") + + @property + def querystring(self) -> QueryString: + query = " ON CONFLICT" + values = [] + + if self.target: + query += f" {self.target_string}" + + if self.action: + query += " {}" + values.append(self.action_string) + + if self.where: + query += " WHERE {}" + values.append(self.where.querystring) + + return QueryString(query, *values) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class OnConflict: + """ + Multiple `ON CONFLICT` statements are allowed - which is why we have this + parent class. 
+ """ + + on_conflict_items: list[OnConflictItem] = field(default_factory=list) + + @property + def querystring(self) -> QueryString: + query = "".join("{}" for i in self.on_conflict_items) + return QueryString( + query, *[i.querystring for i in self.on_conflict_items] + ) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class OnConflictDelegate: + """ + Used with insert queries to specify what to do when a query fails due to + a constraint:: + + .on_conflict(action='DO NOTHING') + + .on_conflict(action='DO UPDATE', values=[Band.popularity]) + + .on_conflict(action='DO UPDATE', values=[(Band.popularity, 1)]) + + """ + + _on_conflict: OnConflict = field(default_factory=OnConflict) + + def on_conflict( + self, + target: Optional[Union[str, Column, tuple[Column, ...]]] = None, + action: Union[ + OnConflictAction, Literal["DO NOTHING", "DO UPDATE"] + ] = OnConflictAction.do_nothing, + values: Optional[Sequence[Union[Column, tuple[Column, Any]]]] = None, + where: Optional[Combinable] = None, + ): + action_: OnConflictAction + if isinstance(action, OnConflictAction): + action_ = action + elif isinstance(action, str): + action_ = OnConflictAction(action.upper()) + else: + raise ValueError("Unrecognised `on conflict` action.") + + if target is None and action_ == OnConflictAction.do_update: + raise ValueError( + "The `target` option must be provided with DO UPDATE." + ) + + if where and action_ == OnConflictAction.do_nothing: + raise ValueError( + "The `where` option can only be used with DO NOTHING." + ) + + self._on_conflict.on_conflict_items.append( + OnConflictItem( + target=target, action=action_, values=values, where=where + ) + ) + + +class LockStrength(str, Enum): + """ + Specify lock strength + + https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE + """ + + update = "UPDATE" + no_key_update = "NO KEY UPDATE" + share = "SHARE" + key_share = "KEY SHARE" + + +@dataclass +class LockRows: + __slots__ = ("lock_strength", "nowait", "skip_locked", "of") + + lock_strength: LockStrength + nowait: bool + skip_locked: bool + of: tuple[type[Table], ...] + + def __post_init__(self): + if not isinstance(self.lock_strength, LockStrength): + raise TypeError("lock_strength must be a LockStrength") + if not isinstance(self.nowait, bool): + raise TypeError("nowait must be a bool") + if not isinstance(self.skip_locked, bool): + raise TypeError("skip_locked must be a bool") + if not isinstance(self.of, tuple) or not all( + hasattr(x, "_meta") for x in self.of + ): + raise TypeError("of must be a tuple of Table") + if self.nowait and self.skip_locked: + raise TypeError( + "The nowait option cannot be used with skip_locked" + ) + + @property + def querystring(self) -> QueryString: + sql = f" FOR {self.lock_strength.value}" + if self.of: + tables = ", ".join( + i._meta.get_formatted_tablename() for i in self.of + ) + sql += " OF " + tables + if self.nowait: + sql += " NOWAIT" + if self.skip_locked: + sql += " SKIP LOCKED" + + return QueryString(sql) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class LockRowsDelegate: + + _lock_rows: Optional[LockRows] = None + + def lock_rows( + self, + lock_strength: Union[ + LockStrength, + Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait=False, + skip_locked=False, + of: tuple[type[Table], ...] 
= (), + ): + lock_strength_: LockStrength + if isinstance(lock_strength, LockStrength): + lock_strength_ = lock_strength + elif isinstance(lock_strength, str): + lock_strength_ = LockStrength(lock_strength.upper()) + else: + raise ValueError("Unrecognised `lock_strength` value.") + + self._lock_rows = LockRows(lock_strength_, nowait, skip_locked, of) diff --git a/piccolo/query/operators/__init__.py b/piccolo/query/operators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/piccolo/query/operators/json.py b/piccolo/query/operators/json.py new file mode 100644 index 000000000..be7529135 --- /dev/null +++ b/piccolo/query/operators/json.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Union + +from piccolo.querystring import QueryString +from piccolo.utils.encoding import dump_json + +if TYPE_CHECKING: + from piccolo.columns.column_types import JSON + + +class JSONQueryString(QueryString): + + def clean_value(self, value: Any): + if not isinstance(value, (str, QueryString)): + value = dump_json(value) + return value + + def __eq__(self, value) -> QueryString: # type: ignore[override] + value = self.clean_value(value) + return QueryString("{} = {}", self, value) + + def __ne__(self, value) -> QueryString: # type: ignore[override] + value = self.clean_value(value) + return QueryString("{} != {}", self, value) + + def eq(self, value) -> QueryString: + return self.__eq__(value) + + def ne(self, value) -> QueryString: + return self.__ne__(value) + + +class GetChildElement(JSONQueryString): + """ + Allows you to get a child element from a JSON object. + + You can access this via the ``arrow`` function on ``JSON`` and ``JSONB`` + columns. + + """ + + def __init__( + self, + identifier: Union[JSON, QueryString], + key: Union[str, int, QueryString], + alias: Optional[str] = None, + ): + if isinstance(key, int): + # asyncpg only accepts integer keys if we explicitly mark it as an + # int. + key = QueryString("{}::int", key) + + super().__init__("{} -> {}", identifier, key, alias=alias) + + def arrow(self, key: Union[str, int, QueryString]) -> GetChildElement: + """ + This allows you to drill multiple levels deep into a JSON object if + needed. + + For example:: + + >>> await RecordingStudio.select( + ... RecordingStudio.name, + ... RecordingStudio.facilities.arrow( + ... "instruments" + ... ).arrow( + ... "drum_kits" + ... ).as_alias("drum_kits") + ... ).output(load_json=True) + [ + {'name': 'Abbey Road', 'drum_kits': 2}, + {'name': 'Electric Lady', 'drum_kits': 3} + ] + + """ + return GetChildElement(identifier=self, key=key, alias=self._alias) + + def __getitem__( + self, value: Union[str, int, QueryString] + ) -> GetChildElement: + return GetChildElement(identifier=self, key=value, alias=self._alias) + + +class GetElementFromPath(JSONQueryString): + """ + Allows you to retrieve an element from a JSON object by specifying a path. + It can be several levels deep. + + You can access this via the ``from_path`` function on ``JSON`` and + ``JSONB`` columns. + + """ + + def __init__( + self, + identifier: Union[JSON, QueryString], + path: list[Union[str, int]], + alias: Optional[str] = None, + ): + """ + :param path: + For example: ``["technician", 0, "name"]``. 
+ + """ + super().__init__( + "{} #> {}", + identifier, + [str(i) if isinstance(i, int) else i for i in path], + alias=alias, + ) diff --git a/piccolo/query/proxy.py b/piccolo/query/proxy.py new file mode 100644 index 000000000..03b9827ca --- /dev/null +++ b/piccolo/query/proxy.py @@ -0,0 +1,69 @@ +import inspect +from collections.abc import Generator +from typing import Generic, Optional, TypeVar + +from typing_extensions import Protocol + +from piccolo.query.base import FrozenQuery +from piccolo.utils.sync import run_sync + + +class Runnable(Protocol): + async def run(self, node: Optional[str] = None, in_pool: bool = True): ... + + +QueryType = TypeVar("QueryType", bound=Runnable) +ResponseType = TypeVar("ResponseType") + + +class Proxy(Generic[QueryType, ResponseType]): + def __init__(self, query: QueryType): + self.query = query + + async def run( + self, + node: Optional[str] = None, + in_pool: bool = True, + ) -> ResponseType: + return await self.query.run(node=node, in_pool=in_pool) + + def run_sync(self, *args, **kwargs) -> ResponseType: + return run_sync(self.run(*args, **kwargs)) + + def __await__( + self, + ) -> Generator[None, None, ResponseType]: + """ + If the user doesn't explicity call .run(), proxy to it as a + convenience. + """ + return self.run().__await__() + + def freeze(self): + self.query.freeze() + return FrozenQuery(query=self) + + def __getattr__(self, name: str): + """ + Proxy any attributes to the underlying query, so all of the query + clauses continue to work. + """ + attr = getattr(self.query, name) + + if inspect.ismethod(attr): + # We do this to preserve the fluent interface. + + def proxy(*args, **kwargs): + response = attr(*args, **kwargs) + if isinstance(response, self.query.__class__): + self.query = response + return self + else: + return response + + return proxy + else: + return attr + + def __str__(self) -> str: + return self.query.__str__() diff --git a/piccolo/querystring.py b/piccolo/querystring.py index 3c6b25c57..ea4b686c8 100644 --- a/piccolo/querystring.py +++ b/piccolo/querystring.py @@ -1,24 +1,51 @@ from __future__ import annotations -import datetime -import typing as t +from abc import ABCMeta, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass +from datetime import datetime +from importlib.util import find_spec from string import Formatter +from typing import TYPE_CHECKING, Any, Optional -if t.TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover + from piccolo.columns import Column from piccolo.table import Table +from uuid import UUID -@dataclass -class Unquoted: +if find_spec("asyncpg"): + from asyncpg.pgproto.pgproto import UUID as apgUUID +else: + apgUUID = UUID + + +class Selectable(metaclass=ABCMeta): """ - Used when we want the value to be unquoted because it's a Postgres - keyword - for example DEFAULT. + Anything which inherits from this can be used in a select query. """ - __slots__ = ("value",) + __slots__ = ("_alias",) + + _alias: Optional[str] + + @abstractmethod + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: + """ + In a query, what to output after the select statement - could be a + column name, a sub query, a function etc. For a column it will be the + column name. + """ + raise NotImplementedError() - value: str + def as_alias(self, alias: str) -> Selectable: + """ + Allows column names to be changed in the result of a select. 
+ """ + self._alias = alias + return self @dataclass @@ -28,7 +55,7 @@ class Fragment: no_arg: bool = False -class QueryString: +class QueryString(Selectable): """ When we're composing complex queries, we're combining QueryStrings, rather than concatenating strings directly. The reason for this is QueryStrings @@ -42,42 +69,81 @@ class QueryString: "query_type", "table", "_frozen_compiled_strings", + "columns", ) def __init__( self, template: str, - *args: t.Any, + *args: Any, query_type: str = "generic", - table: t.Optional[t.Type[Table]] = None, + table: Optional[type[Table]] = None, + alias: Optional[str] = None, ) -> None: """ - Example template: "WHERE {} = {}" + :param template: + The SQL query, with curly brackets as placeholders for any values:: + + "WHERE {} = {}" + + :param args: + The values to insert (one value is needed for each set of curly + braces in the template). + :param query_type: + The query type is sometimes used by the engine to modify how the + query is run. For example, INSERT queries on old SQLite versions. + :param table: + Sometimes the ``piccolo.engine.base.Engine`` needs access to the + table that the query is being run on. - The query type is sometimes used by the engine to modify how the query - is run. """ self.template = template - self.args = args self.query_type = query_type self.table = table - self._frozen_compiled_strings: t.Optional[ - t.Tuple[str, t.List[t.Any]] - ] = None + self._frozen_compiled_strings: Optional[tuple[str, list[Any]]] = None + self._alias = alias + self.args, self.columns = self.process_args(args) + + def process_args( + self, args: Sequence[Any] + ) -> tuple[Sequence[Any], Sequence[Column]]: + """ + If a Column is passed in, we convert it to the name of the column + (including joins). + """ + from piccolo.columns import Column + + processed_args = [] + columns = [] + + for arg in args: + if isinstance(arg, Column): + columns.append(arg) + arg = QueryString( + f"{arg._meta.get_full_name(with_alias=False)}" + ) + elif isinstance(arg, QueryString): + columns.extend(arg.columns) + + processed_args.append(arg) + + return (processed_args, columns) + + def as_alias(self, alias: str) -> QueryString: + self._alias = alias + return self def __str__(self): """ - The SQL returned by the __str__ method isn't used directly in queries - - it's just a usability feature. + The SQL returned by the ``__str__`` method isn't used directly in + queries - it's just a usability feature. """ _, bundled, combined_args = self.bundle( start_index=1, bundled=[], combined_args=[] ) template = "".join( - [ - fragment.prefix + ("" if fragment.no_arg else "{}") - for fragment in bundled - ] + fragment.prefix + ("" if fragment.no_arg else "{}") + for fragment in bundled ) # Do some basic type conversion here. @@ -87,8 +153,10 @@ def __str__(self): if _type == str: converted_args.append(f"'{arg}'") elif _type == datetime: - dt_string = arg.isoformat().replace("T", " ") + dt_string = arg.isoformat() converted_args.append(f"'{dt_string}'") + elif _type == UUID or _type == apgUUID: + converted_args.append(f"'{arg}'") elif arg is None: converted_args.append("null") else: @@ -99,8 +167,8 @@ def __str__(self): def bundle( self, start_index: int = 1, - bundled: t.Optional[t.List[Fragment]] = None, - combined_args: t.Optional[t.List] = None, + bundled: Optional[list[Fragment]] = None, + combined_args: Optional[list] = None, ): # Split up the string, separating by {}. 
fragments = [ @@ -118,7 +186,7 @@ def bundle( fragment.no_arg = True bundled.append(fragment) else: - if type(value) == self.__class__: + if isinstance(value, QueryString): fragment.no_arg = True bundled.append(fragment) @@ -137,7 +205,7 @@ def bundle( def compile_string( self, engine_type: str = "postgres" - ) -> t.Tuple[str, t.List[t.Any]]: + ) -> tuple[str, list[Any]]: """ Compiles the template ready for the engine - keeping the arguments separate from the template. @@ -148,21 +216,19 @@ def compile_string( _, bundled, combined_args = self.bundle( start_index=1, bundled=[], combined_args=[] ) - if engine_type == "postgres": + if engine_type in ("postgres", "cockroach"): string = "".join( - [ - fragment.prefix - + ("" if fragment.no_arg else f"${fragment.index}") - for fragment in bundled - ] + fragment.prefix + + ("" if fragment.no_arg else f"${fragment.index}") + for fragment in bundled ) + elif engine_type == "sqlite": string = "".join( - [ - fragment.prefix + ("" if fragment.no_arg else "?") - for fragment in bundled - ] + fragment.prefix + ("" if fragment.no_arg else "?") + for fragment in bundled ) + else: raise Exception("Engine type not recognised") @@ -172,3 +238,89 @@ def freeze(self, engine_type: str = "postgres"): self._frozen_compiled_strings = self.compile_string( engine_type=engine_type ) + + ########################################################################### + + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: + if with_alias and self._alias: + return QueryString("{} AS " + f'"{self._alias}"', self) + else: + return self + + def get_where_string(self, engine_type: str) -> QueryString: + return self.get_select_string( + engine_type=engine_type, with_alias=False + ) + + ########################################################################### + # Basic logic + + def __eq__(self, value) -> QueryString: # type: ignore[override] + if value is None: + return QueryString("{} IS NULL", self) + else: + return QueryString("{} = {}", self, value) + + def __ne__(self, value) -> QueryString: # type: ignore[override] + if value is None: + return QueryString("{} IS NOT NULL", self, value) + else: + return QueryString("{} != {}", self, value) + + def eq(self, value) -> QueryString: + return self.__eq__(value) + + def ne(self, value) -> QueryString: + return self.__ne__(value) + + def __add__(self, value) -> QueryString: + return QueryString("{} + {}", self, value) + + def __sub__(self, value) -> QueryString: + return QueryString("{} - {}", self, value) + + def __gt__(self, value) -> QueryString: + return QueryString("{} > {}", self, value) + + def __ge__(self, value) -> QueryString: + return QueryString("{} >= {}", self, value) + + def __lt__(self, value) -> QueryString: + return QueryString("{} < {}", self, value) + + def __le__(self, value) -> QueryString: + return QueryString("{} <= {}", self, value) + + def __truediv__(self, value) -> QueryString: + return QueryString("{} / {}", self, value) + + def __mul__(self, value) -> QueryString: + return QueryString("{} * {}", self, value) + + def __pow__(self, value) -> QueryString: + return QueryString("{} ^ {}", self, value) + + def __mod__(self, value) -> QueryString: + return QueryString("{} % {}", self, value) + + def is_in(self, value) -> QueryString: + return QueryString("{} IN {}", self, value) + + def not_in(self, value) -> QueryString: + return QueryString("{} NOT IN {}", self, value) + + def like(self, value: str) -> QueryString: + return QueryString("{} LIKE {}", self, value) + 
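These operator overloads make ``QueryString`` fragments composable like column expressions: a nested ``QueryString`` is spliced into its parent's template, while plain values stay as bound parameters. A small sketch of how composition compiles (the values are illustrative):

```python
from piccolo.querystring import QueryString

price = QueryString("price * {}", 1.2)  # one `{}` per bound argument
condition = price > 100  # equivalent to QueryString("{} > {}", price, 100)

template, args = condition.compile_string(engine_type="postgres")
# template == "price * $1 > $2", args == [1.2, 100]
# With engine_type="sqlite", the placeholders become "?" instead.
```

Because a nested ``QueryString`` contributes its own fragments rather than a parameter, the ``$n`` numbering stays contiguous across the whole expression.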
+    def ilike(self, value: str) -> QueryString:
+        return QueryString("{} ILIKE {}", self, value)
+
+
+class Unquoted(QueryString):
+    """
+    This is deprecated - just use QueryString directly.
+    """
+
+    pass
diff --git a/piccolo/schema.py b/piccolo/schema.py
new file mode 100644
index 000000000..8d949a700
--- /dev/null
+++ b/piccolo/schema.py
@@ -0,0 +1,316 @@
+from __future__ import annotations
+
+import abc
+from typing import Optional, cast
+
+from piccolo.engine.base import Engine
+from piccolo.engine.finder import engine_finder
+from piccolo.querystring import QueryString
+from piccolo.utils.sync import run_sync
+
+
+class SchemaDDLBase(abc.ABC):
+    db: Engine
+
+    @property
+    @abc.abstractmethod
+    def ddl(self) -> str:
+        pass
+
+    def __await__(self):
+        return self.run().__await__()
+
+    async def run(self, in_pool=True):
+        return await self.db.run_ddl(self.ddl, in_pool=in_pool)
+
+    def run_sync(self, *args, **kwargs):
+        return run_sync(self.run(*args, **kwargs))
+
+    def __str__(self) -> str:
+        return self.ddl.__str__()
+
+
+class CreateSchema(SchemaDDLBase):
+    def __init__(
+        self,
+        schema_name: str,
+        *,
+        if_not_exists: bool,
+        db: Engine,
+    ):
+        self.schema_name = schema_name
+        self.if_not_exists = if_not_exists
+        self.db = db
+
+    async def run(self, *args, **kwargs):
+        if self.schema_name == "public" or self.schema_name is None:
+            return
+
+        return await super().run(*args, **kwargs)
+
+    @property
+    def ddl(self) -> str:
+        query = "CREATE SCHEMA"
+        if self.if_not_exists:
+            query += " IF NOT EXISTS"
+        query += f' "{self.schema_name}"'
+
+        return query
+
+
+class DropSchema(SchemaDDLBase):
+    def __init__(
+        self,
+        schema_name: str,
+        *,
+        if_exists: bool,
+        cascade: bool,
+        db: Engine,
+    ):
+        self.schema_name = schema_name
+        self.if_exists = if_exists
+        self.cascade = cascade
+        self.db = db
+
+    @property
+    def ddl(self) -> str:
+        query = "DROP SCHEMA"
+        if self.if_exists:
+            query += " IF EXISTS"
+        query += f' "{self.schema_name}"'
+
+        if self.cascade:
+            query += " CASCADE"
+
+        return query
+
+
+class RenameSchema(SchemaDDLBase):
+    def __init__(
+        self,
+        schema_name: str,
+        new_schema_name: str,
+        db: Engine,
+    ):
+        self.schema_name = schema_name
+        self.new_schema_name = new_schema_name
+        self.db = db
+
+    @property
+    def ddl(self):
+        return (
+            f'ALTER SCHEMA "{self.schema_name}" '
+            f'RENAME TO "{self.new_schema_name}"'
+        )
+
+
+class MoveTable(SchemaDDLBase):
+    def __init__(
+        self,
+        table_name: str,
+        new_schema: str,
+        db: Engine,
+        current_schema: Optional[str] = None,
+    ):
+        self.table_name = table_name
+        self.current_schema = current_schema
+        self.new_schema = new_schema
+        self.db = db
+
+    @property
+    def ddl(self):
+        table_name = f'"{self.table_name}"'
+        if self.current_schema:
+            table_name = f'"{self.current_schema}".{table_name}'
+
+        return f'ALTER TABLE {table_name} SET SCHEMA "{self.new_schema}"'
+
+
+class ListTables:
+    def __init__(self, db: Engine, schema_name: str):
+        self.db = db
+        self.schema_name = schema_name
+
+    async def run(self) -> list[str]:
+        response = cast(
+            list[dict],
+            await self.db.run_querystring(
+                QueryString(
+                    """
+                    SELECT table_name
+                    FROM information_schema.tables
+                    WHERE table_schema = {}
+                    """,
+                    self.schema_name,
+                )
+            ),
+        )
+        return [i["table_name"] for i in response]
+
+    def run_sync(self):
+        return run_sync(self.run())
+
+    def __await__(self):
+        return self.run().__await__()
+
+
+class ListSchemas:
+    def __init__(self, db: Engine):
+        self.db = db
+
+    async def run(self) -> list[str]:
+        response = cast(
+            list[dict],
+            await self.db.run_querystring(
+                QueryString(
+                    "SELECT schema_name FROM information_schema.schemata"
+                )
+            ),
+        )
+        return [i["schema_name"] for i in response]
+
+    def run_sync(self):
+        return run_sync(self.run())
+
+    def __await__(self):
+        return self.run().__await__()
+
+
+class SchemaManager:
+    def __init__(self, db: Optional[Engine] = None):
+        """
+        A useful utility class for interacting with schemas.
+
+        :param db:
+            Used to execute the database queries. If not specified, we try and
+            import it from ``piccolo_conf.py``.
+        """
+        db = db or engine_finder()
+
+        if db is None:
+            raise ValueError("The DB can't be found.")
+
+        self.db = db
+
+    def create_schema(
+        self, schema_name: str, *, if_not_exists: bool = True
+    ) -> CreateSchema:
+        """
+        Creates the specified schema::
+
+            >>> await SchemaManager().create_schema(schema_name="music")
+
+        :param schema_name:
+            The name of the schema to create.
+        :param if_not_exists:
+            No error will be raised if the schema already exists.
+
+        """
+        return CreateSchema(
+            schema_name=schema_name,
+            if_not_exists=if_not_exists,
+            db=self.db,
+        )
+
+    def drop_schema(
+        self,
+        schema_name: str,
+        *,
+        if_exists: bool = True,
+        cascade: bool = False,
+    ) -> DropSchema:
+        """
+        Drops the specified schema::
+
+            >>> await SchemaManager().drop_schema(schema_name="music")
+
+        :param schema_name:
+            The name of the schema to drop.
+        :param if_exists:
+            No error will be raised if the schema doesn't exist.
+        :param cascade:
+            If ``True`` then it will automatically drop the tables within the
+            schema.
+
+        """
+        return DropSchema(
+            schema_name=schema_name,
+            if_exists=if_exists,
+            cascade=cascade,
+            db=self.db,
+        )
+
+    def rename_schema(
+        self, schema_name: str, new_schema_name: str
+    ) -> RenameSchema:
+        """
+        Rename the schema::
+
+            >>> await SchemaManager().rename_schema(
+            ...     schema_name="music",
+            ...     new_schema_name="music_info"
+            ... )
+
+        :param schema_name:
+            The current name of the schema.
+        :param new_schema_name:
+            What to rename the schema to.
+
+        """
+        return RenameSchema(
+            schema_name=schema_name,
+            new_schema_name=new_schema_name,
+            db=self.db,
+        )
+
+    def move_table(
+        self,
+        table_name: str,
+        new_schema: str,
+        current_schema: Optional[str] = None,
+    ) -> MoveTable:
+        """
+        Moves a table to a different schema::
+
+            >>> await SchemaManager().move_table(
+            ...     table_name='my_table',
+            ...     new_schema='schema_1'
+            ... )
+
+        :param table_name:
+            The name of the table to move.
+        :param new_schema:
+            The name of the schema you want to move the table to.
+        :param current_schema:
+            If not specified, ``'public'`` is assumed.
+
+        """
+        return MoveTable(
+            table_name=table_name,
+            new_schema=new_schema,
+            current_schema=current_schema,
+            db=self.db,
+        )
+
+    def list_tables(self, schema_name: str) -> ListTables:
+        """
+        Returns the name of each table in the given schema::
+
+            >>> await SchemaManager().list_tables(schema_name="music")
+            ['band', 'manager']
+
+        :param schema_name:
+            List the tables in this schema.
+ + """ + return ListTables(db=self.db, schema_name=schema_name) + + def list_schemas(self) -> ListSchemas: + """ + Returns the name of each schema in the database:: + + >>> await SchemaManager().list_schemas() + ['public', 'schema_1'] + + """ + return ListSchemas(db=self.db) diff --git a/piccolo/table.py b/piccolo/table.py index 1b05ace58..bdfda2cdd 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -3,24 +3,33 @@ import inspect import itertools import types -import typing as t +import warnings +from collections.abc import Sequence from dataclasses import dataclass, field +from graphlib import TopologicalSorter +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from piccolo.columns import Column from piccolo.columns.column_types import ( JSON, JSONB, + Array, + Email, ForeignKey, - Secret, + ReferencedTable, Serial, ) from piccolo.columns.defaults.base import Default from piccolo.columns.indexes import IndexMethod -from piccolo.columns.readable import Readable -from piccolo.columns.reference import ( - LAZY_COLUMN_REFERENCES, - LazyTableReference, +from piccolo.columns.m2m import ( + M2M, + M2MAddRelated, + M2MGetRelated, + M2MRemoveRelated, ) +from piccolo.columns.readable import Readable +from piccolo.columns.reference import LAZY_COLUMN_REFERENCES +from piccolo.custom_types import TableInstance from piccolo.engine import Engine, engine_finder from piccolo.query import ( Alter, @@ -38,15 +47,25 @@ ) from piccolo.query.methods.create_index import CreateIndex from piccolo.query.methods.indexes import Indexes -from piccolo.querystring import QueryString, Unquoted +from piccolo.query.methods.objects import GetRelated, UpdateSelf +from piccolo.query.methods.refresh import Refresh +from piccolo.querystring import QueryString from piccolo.utils import _camel_to_snake from piccolo.utils.sql_values import convert_to_sql_value +from piccolo.utils.sync import run_sync +from piccolo.utils.warnings import colored_warning -if t.TYPE_CHECKING: - from piccolo.columns import Selectable - +if TYPE_CHECKING: # pragma: no cover + from piccolo.querystring import Selectable PROTECTED_TABLENAMES = ("user",) +TABLENAME_WARNING = ( + "We recommend giving your table a different name as `{tablename}` is a " + "reserved keyword. It should still work, but avoid if possible." 
+) + + +TABLE_REGISTRY: list[type[Table]] = [] @dataclass @@ -56,28 +75,53 @@ class TableMeta: """ tablename: str = "" - columns: t.List[Column] = field(default_factory=list) - default_columns: t.List[Column] = field(default_factory=list) - non_default_columns: t.List[Column] = field(default_factory=list) - foreign_key_columns: t.List[ForeignKey] = field(default_factory=list) + columns: list[Column] = field(default_factory=list) + default_columns: list[Column] = field(default_factory=list) + non_default_columns: list[Column] = field(default_factory=list) + array_columns: list[Array] = field(default_factory=list) + email_columns: list[Email] = field(default_factory=list) + foreign_key_columns: list[ForeignKey] = field(default_factory=list) primary_key: Column = field(default_factory=Column) - json_columns: t.List[t.Union[JSON, JSONB]] = field(default_factory=list) - secret_columns: t.List[Secret] = field(default_factory=list) - tags: t.List[str] = field(default_factory=list) - help_text: t.Optional[str] = None - _db: t.Optional[Engine] = None + json_columns: list[Union[JSON, JSONB]] = field(default_factory=list) + secret_columns: list[Column] = field(default_factory=list) + auto_update_columns: list[Column] = field(default_factory=list) + tags: list[str] = field(default_factory=list) + help_text: Optional[str] = None + _db: Optional[Engine] = None + m2m_relationships: list[M2M] = field(default_factory=list) + schema: Optional[str] = None # Records reverse foreign key relationships - i.e. when the current table # is the target of a foreign key. Used by external libraries such as # Piccolo API. - _foreign_key_references: t.List[ForeignKey] = field(default_factory=list) + _foreign_key_references: list[ForeignKey] = field(default_factory=list) - @property - def foreign_key_references(self) -> t.List[ForeignKey]: - foreign_keys: t.List[ForeignKey] = [] - for reference in self._foreign_key_references: - foreign_keys.append(reference) + def get_formatted_tablename( + self, include_schema: bool = True, quoted: bool = True + ) -> str: + """ + Returns the tablename, in the desired format. + + :param include_schema: + If ``True``, the Postgres schema is included. For example, + 'my_schema.my_table'. + :param quote: + If ``True``, the name is wrapped in double quotes. For example, + '"my_schema"."my_table"'. + + """ + components = [self.tablename] + if include_schema and self.schema: + components.insert(0, self.schema) + + if quoted: + return ".".join(f'"{i}"' for i in components) + else: + return ".".join(components) + @property + def foreign_key_references(self) -> list[ForeignKey]: + foreign_keys: list[ForeignKey] = list(self._foreign_key_references) lazy_column_references = LAZY_COLUMN_REFERENCES.for_tablename( tablename=self.tablename ) @@ -95,6 +139,16 @@ def db(self) -> Engine: return self._db + @db.setter + def db(self, value: Engine): + self._db = value + + def refresh_db(self) -> None: + engine = engine_finder() + if engine is None: + raise ValueError("The engine can't be found") + self.db = engine + def get_column_by_name(self, name: str) -> Column: """ Returns a column which matches the given name. 
     def get_column_by_name(self, name: str) -> Column:
         """
         Returns a column which matches the given name. It will try and follow
@@ -112,20 +166,60 @@ def get_column_by_name(self, name: str) -> Column:
         for reference_name in components[1:]:
             try:
                 column_object = getattr(column_object, reference_name)
-            except AttributeError:
+            except AttributeError as e:
                 raise ValueError(
                     f"Unable to find column - {reference_name}"
-                )
+                ) from e

         return column_object

+    def get_auto_update_values(self) -> dict[Column, Any]:
+        """
+        If columns have ``auto_update`` defined, then we retrieve these
+        values.
+        """
+        output: dict[Column, Any] = {}
+        for column in self.auto_update_columns:
+            value = column._meta.auto_update
+            if callable(value):
+                value = value()
+            output[column] = value
+        return output
+

 class TableMetaclass(type):
-    def __str__(cls):
-        return cls._table_str()
+    def __str__(cls) -> str:
+        return cls._table_str()  # type: ignore
+
+    def __repr__(cls):
+        """
+        We override this, because by default Python will output something
+        like::
+
+            >>> repr(MyTable)
+            <class 'my_app.tables.MyTable'>
+
+        It's a very common pattern in Piccolo and its sister libraries to
+        have ``Table`` class types as default values::
+
+            # `SessionsBase` is a `Table` subclass:
+            def session_auth(
+                session_table: type[SessionsBase] = SessionsBase
+            ):
+                ...
+
+        This looks terrible in Sphinx's autodoc output, as Python's default
+        repr contains angled brackets, which breaks the HTML output. So we
+        just output the name instead. The user can still easily find which
+        module a ``Table`` subclass belongs to by using
+        ``MyTable.__module__``.
+
+        """
+        return cls.__name__


 class Table(metaclass=TableMetaclass):
+    """
+    The class represents a database table. An instance represents a row.
+    """

     # These are just placeholder values, so type inference isn't confused -
     # the actual values are set in __init_subclass__.
@@ -133,11 +227,12 @@ class Table(metaclass=TableMetaclass):

     def __init_subclass__(
         cls,
-        tablename: t.Optional[str] = None,
-        db: t.Optional[Engine] = None,
-        tags: t.List[str] = [],
-        help_text: t.Optional[str] = None,
-    ):
+        tablename: Optional[str] = None,
+        db: Optional[Engine] = None,
+        tags: Optional[list[str]] = None,
+        help_text: Optional[str] = None,
+        schema: Optional[str] = None,
+    ):  # sourcery no-metrics
         """
         Automatically populate the _meta, which includes the tablename, and
         columns.
@@ -155,23 +250,35 @@ def __init_subclass__(
             A user friendly description of what the table is used for. It
             isn't used in the database, but will be used by tools such as
             Piccolo Admin for tooltips.
+        :param schema:
+            The Postgres schema to use for this table.

         """
-        tablename = tablename if tablename else _camel_to_snake(cls.__name__)
-
-        if tablename in PROTECTED_TABLENAMES:
-            raise ValueError(
-                f"{tablename} is a protected name, please give your table a "
-                "different name."
+        if tags is None:
+            tags = []
+        tablename = tablename or _camel_to_snake(cls.__name__)
+
+        if "." in tablename:
+            warnings.warn(
+                "There's a '.' in the tablename - please use the `schema` "
+                "argument instead."
) + schema, tablename = tablename.split(".", maxsplit=1) - columns: t.List[Column] = [] - default_columns: t.List[Column] = [] - non_default_columns: t.List[Column] = [] - foreign_key_columns: t.List[ForeignKey] = [] - secret_columns: t.List[Secret] = [] - json_columns: t.List[t.Union[JSON, JSONB]] = [] - primary_key: t.Optional[Column] = None + if tablename in PROTECTED_TABLENAMES: + warnings.warn(TABLENAME_WARNING.format(tablename=tablename)) + + columns: list[Column] = [] + default_columns: list[Column] = [] + non_default_columns: list[Column] = [] + array_columns: list[Array] = [] + foreign_key_columns: list[ForeignKey] = [] + secret_columns: list[Column] = [] + json_columns: list[Union[JSON, JSONB]] = [] + email_columns: list[Email] = [] + auto_update_columns: list[Column] = [] + primary_key: Optional[Column] = None + m2m_relationships: list[M2M] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -200,8 +307,12 @@ def __init_subclass__( column._meta._name = attribute_name column._meta._table = cls - if isinstance(column, Secret): - secret_columns.append(column) + if isinstance(column, Array): + column._setup_base_column(table_class=cls) + array_columns.append(column) + + if isinstance(column, Email): + email_columns.append(column) if isinstance(column, ForeignKey): foreign_key_columns.append(column) @@ -209,6 +320,17 @@ def __init_subclass__( if isinstance(column, (JSON, JSONB)): json_columns.append(column) + if column._meta.secret: + secret_columns.append(column) + + if column._meta.auto_update is not ...: + auto_update_columns.append(column) + + if isinstance(attribute, M2M): + attribute._meta._name = attribute_name + attribute._meta._table = cls + m2m_relationships.append(attribute) + if not primary_key: primary_key = cls._create_serial_primary_key() setattr(cls, "id", primary_key) @@ -221,82 +343,89 @@ def __init_subclass__( columns=columns, default_columns=default_columns, non_default_columns=non_default_columns, + array_columns=array_columns, + email_columns=email_columns, primary_key=primary_key, foreign_key_columns=foreign_key_columns, json_columns=json_columns, secret_columns=secret_columns, + auto_update_columns=auto_update_columns, tags=tags, help_text=help_text, _db=db, + m2m_relationships=m2m_relationships, + schema=schema, ) for foreign_key_column in foreign_key_columns: - params = foreign_key_column._meta.params - references = params["references"] - - if isinstance(references, str): - if references == "self": - references = cls - else: - if "." in references: - # Don't allow relative modules - this may change in - # the future. - if references.startswith("."): - raise ValueError("Relative imports aren't allowed") - - module_path, table_class_name = references.rsplit( - ".", maxsplit=1 - ) - else: - table_class_name = references - module_path = cls.__module__ - - references = LazyTableReference( - table_class_name=table_class_name, - module_path=module_path, - ) - - is_lazy = isinstance(references, LazyTableReference) - is_table_class = inspect.isclass(references) and issubclass( - references, Table + # ForeignKey columns require additional setup based on their + # parent Table. + foreign_key_setup_response = foreign_key_column._setup( + table_class=cls ) - - if is_lazy or is_table_class: - foreign_key_column._foreign_key_meta.references = references - else: - raise ValueError( - "Error - ``references`` must be a ``Table`` subclass, or " - "a ``LazyTableReference`` instance." 
-                )
-
-            # Record the reverse relationship on the target table.
-            if is_table_class:
-                references._meta._foreign_key_references.append(
-                    foreign_key_column
-                )
-            elif is_lazy:
+            if foreign_key_setup_response.is_lazy:
                 LAZY_COLUMN_REFERENCES.foreign_key_columns.append(
                     foreign_key_column
                 )

-            # Allow columns on the referenced table to be accessed via
-            # auto completion.
-            if is_table_class:
-                foreign_key_column.set_proxy_columns()
+        TABLE_REGISTRY.append(cls)

     def __init__(
         self,
-        ignore_missing: bool = False,
-        exists_in_db: bool = False,
+        _data: Optional[dict[Column, Any]] = None,
+        _ignore_missing: bool = False,
+        _exists_in_db: bool = False,
         **kwargs,
     ):
         """
-        Assigns any default column values to the class.
+        The constructor can be used to assign column values.
+
+        .. note::
+            The ``_data``, ``_ignore_missing``, and ``_exists_in_db``
+            arguments are prefixed with an underscore to help prevent a
+            clash with a column name which might be passed in via kwargs.
+
+        :param _data:
+            There are two ways of passing in the data for each column.
+            Firstly, you can use kwargs::
+
+                Band(name="Pythonistas")
+
+            Secondly, you can pass in a dictionary which maps column classes
+            to values::
+
+                Band({Band.name: 'Pythonistas'})
+
+            The advantage of this second approach is it's more strongly
+            typed, and linters such as flake8 or MyPy will more easily
+            detect typos.
+
+        :param _ignore_missing:
+            If ``False`` a ``ValueError`` will be raised if any column
+            values haven't been provided.
+        :param _exists_in_db:
+            Used internally to track whether this row exists in the
+            database.

         """
-        self._exists_in_db = exists_in_db
+        _data = _data or {}
+
+        self._exists_in_db = _exists_in_db
+
+        # This is used by get_or_create to indicate to the user whether it
+        # was an existing row or not.
+        self._was_created: Optional[bool] = None

         for column in self._meta.columns:
-            value = kwargs.pop(column._meta.name, ...)
+            value = _data.get(column, ...)
+
+            if kwargs:
+                if value is ...:
+                    value = kwargs.pop(column._meta.name, ...)
+
+                if value is ...:
+                    value = kwargs.pop(
+                        cast(str, column._meta.db_column_name), ...
+                    )
+
             if value is ...:
                 value = column.get_default_value()

@@ -306,7 +435,7 @@ def __init__(
             if (
                 (value is None)
                 and (not column._meta.null)
-                and not ignore_missing
+                and not _ignore_missing
             ):
                 raise ValueError(f"{column._meta.name} wasn't provided")

@@ -314,42 +443,122 @@
         unrecognized = kwargs.keys()

         if unrecognized:
-            unrecognised_list = [i for i in unrecognized]
+            unrecognised_list = list(unrecognized)
             raise ValueError(f"Unrecognized columns - {unrecognised_list}")

     @classmethod
     def _create_serial_primary_key(cls) -> Serial:
-        pk = Serial(index=False, primary_key=True)
+        pk = Serial(index=False, primary_key=True, db_column_name="id")
        pk._meta._name = "id"
        pk._meta._table = cls

         return pk

+    @classmethod
+    def from_dict(
+        cls: type[TableInstance], data: dict[str, Any]
+    ) -> TableInstance:
+        """
+        Used when loading fixtures. It can be overridden by subclasses in
+        case they have specific logic / validation which needs running when
+        loading fixtures.
+        """
+        return cls(**data)
+
     ###########################################################################
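A short sketch of the two constructor styles described in the docstring above, reusing the ``Band`` example table (both lines are equivalent; the second form is the more strongly typed one):

.. code-block:: python

    band_1 = Band(name="Pythonistas", popularity=1000)
    band_2 = Band({Band.name: "Pythonistas", Band.popularity: 1000})
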
-    def save(self) -> t.Union[Insert, Update]:
+    def save(
+        self, columns: Optional[Sequence[Union[Column, str]]] = None
+    ) -> Union[Insert, Update]:
         """
         A proxy to an insert or update query.
+
+        :param columns:
+            Only the specified columns will be synced back to the database
+            when doing an update. For example:
+
+            .. code-block:: python
+
+                band = await Band.objects().first()
+                band.popularity = 2000
+                await band.save(columns=[Band.popularity])
+
+            If ``columns=None`` (the default) then all columns will be
+            synced back to the database.
+
         """
         cls = self.__class__

-        if self._exists_in_db:
-            # pre-existing row
-            kwargs: t.Dict[Column, t.Any] = {
-                i: getattr(self, i._meta.name, None)
-                for i in cls._meta.columns
-                if i._meta.name != self._meta.primary_key._meta.name
-            }
-            return (
-                cls.update()
-                .values(kwargs)  # type: ignore
-                .where(
-                    cls._meta.primary_key
-                    == getattr(self, self._meta.primary_key._meta.name)
-                )
-            )
+        # New row - insert
+        if not self._exists_in_db:
+            return cls.insert(self).returning(cls._meta.primary_key)
+
+        # Pre-existing row - update
+        if columns is None:
+            column_instances = [
+                i for i in cls._meta.columns if not i._meta.primary_key
+            ]
         else:
-            return cls.insert().add(self)
+            column_instances = [
+                self._meta.get_column_by_name(i) if isinstance(i, str) else i
+                for i in columns
+            ]
+
+        values: dict[Column, Any] = {
+            i: getattr(self, i._meta.name, None) for i in column_instances
+        }
+
+        # Assign any `auto_update` values
+        if cls._meta.auto_update_columns:
+            auto_update_values = cls._meta.get_auto_update_values()
+            values.update(auto_update_values)
+            for column, value in auto_update_values.items():
+                setattr(self, column._meta.name, value)
+
+        return cls.update(
+            values,  # type: ignore
+            # We've already included the `auto_update` columns, so no need
+            # to do it again:
+            use_auto_update=False,
+        ).where(
+            cls._meta.primary_key
+            == getattr(self, self._meta.primary_key._meta.name)
+        )
+
+    def update_self(self, values: dict[Union[Column, str], Any]) -> UpdateSelf:
+        """
+        This allows the user to update a single object - useful when the
+        values are derived from the database in some way.
+
+        For example, if we have the following table::
+
+            class Band(Table):
+                name = Varchar()
+                popularity = Integer()
+
+        And we fetch an object::
+
+            >>> band = await Band.objects().get(Band.name == "Pythonistas")
+
+        We could use the typical syntax for updating the object::
+
+            >>> band.popularity += 1
+            >>> await band.save()
+
+        The problem with this is, what if another object has already
+        incremented ``popularity``? It would override the value.
+
+        Instead we can do this::
+
+            >>> await band.update_self({
+            ...     Band.popularity: Band.popularity + 1
+            ... })
+
+        This updates ``popularity`` in the database, and also sets the new
+        value for ``popularity`` on the object.
+
+        """
+        return UpdateSelf(row=self, values=values)

     def remove(self) -> Delete:
         """
@@ -362,21 +571,67 @@

         setattr(self, self._meta.primary_key._meta.name, None)

+        self._exists_in_db = False
+
         return self.__class__.delete().where(
             self.__class__._meta.primary_key == primary_key_value
         )
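A hedged sketch of the ``auto_update`` flow which ``save`` implements above. The ``modified_on`` column is an illustrative assumption (``Timestamp`` and the ``auto_update`` argument are referenced elsewhere in this changeset, but this exact table isn't):

.. code-block:: python

    from datetime import datetime

    class Band(Table):
        name = Varchar()
        # Hypothetical column - the callable is re-evaluated by
        # `get_auto_update_values` on each update:
        modified_on = Timestamp(auto_update=datetime.now)

    band = await Band.objects().first()
    band.name = "Pythonistas Reborn"
    # `save` includes the new `modified_on` value in the update, and also
    # assigns it back onto `band`:
    await band.save()
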
-    def get_related(self, foreign_key: t.Union[ForeignKey, str]) -> Objects:
+    def refresh(
+        self,
+        columns: Optional[Sequence[Column]] = None,
+        load_json: bool = False,
+    ) -> Refresh:
+        """
+        Used to fetch the latest data for this instance from the database.
+        Modifies the instance in place, but also returns it as a
+        convenience.
+
+        :param columns:
+            If you only want to refresh certain columns, specify them here.
+            Otherwise all columns are refreshed.
+
+        :param load_json:
+            Whether to load ``JSON`` / ``JSONB`` columns as objects, instead
+            of just a string.
+
+        Example usage::
+
+            # Get an instance from the database.
+            instance = await Band.objects().first()
+
+            # Later on we can refresh this instance with the latest data
+            # from the database, in case it has gotten stale.
+            await instance.refresh()
+
+            # Alternatively, running it synchronously:
+            instance.refresh().run_sync()
+
         """
-        Used to fetch a Table instance, for the target of a foreign key.
+        return Refresh(instance=self, columns=columns, load_json=load_json)

-        band = await Band.objects().first().run()
-        manager = await band.get_related(Band.manager).run()
-        >>> print(manager.name)
-        'Guido'
+    @overload
+    def get_related(
+        self, foreign_key: ForeignKey[ReferencedTable]
+    ) -> GetRelated[ReferencedTable]: ...

-        It can only follow foreign keys one level currently.
-        i.e. Band.manager, but not Band.manager.x.y.z
+    @overload
+    def get_related(self, foreign_key: str) -> GetRelated[Table]: ...
+
+    def get_related(
+        self, foreign_key: Union[str, ForeignKey[ReferencedTable]]
+    ) -> GetRelated[ReferencedTable]:
+        """
+        Used to fetch a ``Table`` instance, for the target of a foreign key.
+
+        .. code-block:: python
+
+            band = await Band.objects().first()
+            manager = await band.get_related(Band.manager)
+            >>> print(manager.name)
+            'Guido'
+
+        It can also follow foreign keys multiple levels deep. For example,
+        ``Concert.band_1.manager``.

         """
         if isinstance(foreign_key, str):
@@ -390,24 +645,145 @@
                 "ForeignKey column."
             )

-        column_name = foreign_key._meta.name
+        return GetRelated(foreign_key=foreign_key, row=self)
+
+    def get_m2m(self, m2m: M2M) -> M2MGetRelated:
+        """
+        Get all matching rows via the join table.
+
+        .. code-block:: python
+
+            >>> band = await Band.objects().get(Band.name == "Pythonistas")
+            >>> await band.get_m2m(Band.genres)
+            [<Genre: 1>, <Genre: 2>]
+
+        """
+        return M2MGetRelated(row=self, m2m=m2m)
+
+    def add_m2m(
+        self,
+        *rows: Table,
+        m2m: M2M,
+        extra_column_values: dict[Union[Column, str], Any] = {},
+    ) -> M2MAddRelated:
+        """
+        Save the row if it doesn't already exist in the database, and insert
+        an entry into the joining table.
+
+        .. code-block:: python

-        references: t.Type[
-            Table
-        ] = foreign_key._foreign_key_meta.resolved_references
+            >>> band = await Band.objects().get(Band.name == "Pythonistas")
+            >>> await band.add_m2m(
+            ...     Genre(name="Punk rock"),
+            ...     m2m=Band.genres
+            ... )
+            [{'id': 1}]

-        return (
-            references.objects()
-            .where(
-                references._meta.get_column_by_name(
-                    self._meta.primary_key._meta.name
+        :param extra_column_values:
+            If the joining table has additional columns besides the two
+            required foreign keys, you can specify the values for those
+            additional columns. For example, if this is our joining table:
+
+            .. code-block:: python
+
+                class GenreToBand(Table):
+                    band = ForeignKey(Band)
+                    genre = ForeignKey(Genre)
+                    reason = Text()
+
+            We can provide the ``reason`` value:
+
+            .. code-block:: python
+
+                await band.add_m2m(
+                    Genre(name="Punk rock"),
+                    m2m=Band.genres,
+                    extra_column_values={
+                        "reason": "Their second album was very punk."
+                    }
                 )
-                == getattr(self, column_name)
-            )
-            .first()
+
+        """
+        return M2MAddRelated(
+            target_row=self,
+            rows=rows,
+            m2m=m2m,
+            extra_column_values=extra_column_values,
         )
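For context, a minimal sketch of the kind of schema the M2M methods above assume. ``LazyTableReference`` is used here as in Piccolo's M2M documentation; treat the exact wiring as illustrative:

.. code-block:: python

    class Band(Table):
        name = Varchar()
        genres = M2M(LazyTableReference("GenreToBand", module_path=__name__))

    class Genre(Table):
        name = Varchar()
        bands = M2M(LazyTableReference("GenreToBand", module_path=__name__))

    class GenreToBand(Table):
        band = ForeignKey(Band)
        genre = ForeignKey(Genre)
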
-    def __setitem__(self, key: str, value: t.Any):
+    def remove_m2m(self, *rows: Table, m2m: M2M) -> M2MRemoveRelated:
+        """
+        Remove the rows from the joining table.
+
+        .. code-block:: python
+
+            >>> band = await Band.objects().get(Band.name == "Pythonistas")
+            >>> genre = await Genre.objects().get(Genre.name == "Rock")
+            >>> await band.remove_m2m(
+            ...     genre,
+            ...     m2m=Band.genres
+            ... )
+
+        """
+        return M2MRemoveRelated(
+            target_row=self,
+            rows=rows,
+            m2m=m2m,
+        )
+
+    def to_dict(self, *columns: Column) -> dict[str, Any]:
+        """
+        A convenience method which returns a dictionary, mapping column
+        names to values for this table instance.
+
+        .. code-block:: python
+
+            instance = await Manager.objects().get(
+                Manager.name == 'Guido'
+            )
+
+            >>> instance.to_dict()
+            {'id': 1, 'name': 'Guido'}
+
+        If the columns argument is provided, only those columns are included
+        in the output. It also works with column aliases.
+
+        .. code-block:: python
+
+            >>> instance.to_dict(Manager.id, Manager.name.as_alias('title'))
+            {'id': 1, 'title': 'Guido'}
+
+        """
+        # Make sure we're only looking at columns for the current table. If
+        # someone passes in a column for a sub table (for example
+        # `Band.manager.name`), we need to add `Band.manager` so the nested
+        # value appears in the output.
+        filtered_columns = []
+        for column in columns:
+            if column._meta.table == self.__class__:
+                filtered_columns.append(column)
+            else:
+                for parent_column in column._meta.call_chain:
+                    if parent_column._meta.table == self.__class__:
+                        filtered_columns.append(parent_column)
+                        break
+
+        alias_names = {
+            column._meta.name: column._alias for column in filtered_columns
+        }
+
+        output = {}
+        for column in filtered_columns if columns else self._meta.columns:
+            value = getattr(self, column._meta.name)
+            if isinstance(value, Table):
+                value = value.to_dict(*columns)
+
+            output[alias_names.get(column._meta.name) or column._meta.name] = (
+                value
+            )
+        return output
+
+    def __setitem__(self, key: str, value: Any):
         setattr(self, key, value)

     def __getitem__(self, key: str):
@@ -424,16 +800,22 @@ def _get_related_readable(cls, column: ForeignKey) -> Readable:
             column._foreign_key_meta.resolved_references.get_readable()
         )

-        columns = [getattr(column, i._meta.name) for i in readable.columns]
+        output_columns = []
+
+        for readable_column in readable.columns:
+            output_column = column
+            for fk in readable_column._meta.call_chain:
+                output_column = getattr(output_column, fk._meta.name)
+            output_column = getattr(output_column, readable_column._meta.name)
+            output_columns.append(output_column)

         output_name = f"{column._meta.name}_readable"

-        new_readable = Readable(
+        return Readable(
             template=readable.template,
-            columns=columns,
+            columns=output_columns,
             output_name=output_name,
         )
-        return new_readable
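As a reminder of the API which ``_get_related_readable`` builds on, here's a hedged sketch of overriding ``get_readable`` on a table (the ``template`` and ``columns`` arguments mirror the ``Readable`` constructor used above):

.. code-block:: python

    class Band(Table):
        name = Varchar()

        @classmethod
        def get_readable(cls) -> Readable:
            return Readable(template="%s", columns=[cls.name])
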

     @classmethod
     def get_readable(cls) -> Readable:
@@ -449,39 +831,181 @@ def querystring(self) -> QueryString:
         """
         Used when inserting rows.
         """
-        args_dict = {}
-        for col in self._meta.columns:
-            column_name = col._meta.name
-            value = convert_to_sql_value(value=self[column_name], column=col)
-            args_dict[column_name] = value
-
-        def is_unquoted(arg):
-            return type(arg) == Unquoted
-
-        # Strip out any args which are unquoted.
-        filtered_args = [i for i in args_dict.values() if not is_unquoted(i)]
+        args = [
+            convert_to_sql_value(value=self[column._meta.name], column=column)
+            for column in self._meta.columns
+        ]

-        # If unquoted, dump it straight into the query.
-        query = ",".join(
-            [
-                args_dict[column._meta.name].value
-                if is_unquoted(args_dict[column._meta.name])
-                else "{}"
-                for column in self._meta.columns
-            ]
-        )
-        return QueryString(f"({query})", *filtered_args)
+        query = ",".join(["{}" for _ in args])
+        return QueryString(f"({query})", *args)

     def __str__(self) -> str:
         return self.querystring.__str__()

     def __repr__(self) -> str:
-        _pk = self._meta.primary_key if self._meta.primary_key else None
-        return f"<{self.__class__.__name__}: {_pk}>"
+        pk = (
+            None
+            if not self._exists_in_db
+            else getattr(self, self._meta.primary_key._meta.name, None)
+        )
+        return f"<{self.__class__.__name__}: {pk}>"
+
+    def __eq__(self, other: Any) -> bool:
+        """
+        Lets us check if two ``Table`` instances represent the same row in
+        the database, based on their primary key value::
+
+            band_1 = await Band.objects().where(
+                Band.name == "Pythonistas"
+            ).first()
+
+            band_2 = await Band.objects().where(
+                Band.name == "Pythonistas"
+            ).first()
+
+            band_3 = await Band.objects().where(
+                Band.name == "Rustaceans"
+            ).first()
+
+            >>> band_1 == band_2
+            True
+
+            >>> band_1 == band_3
+            False
+
+        """
+        if not isinstance(other, Table):
+            # This is the correct way to tell Python that this operation
+            # isn't supported:
+            # https://docs.python.org/3/library/constants.html#NotImplemented
+            return NotImplemented
+
+        # Make sure we're comparing the same table.
+        # There are several ways we could do this (like comparing tablename),
+        # but this should be OK.
+        if not isinstance(other, self.__class__):
+            return False
+
+        pk = self._meta.primary_key
+
+        pk_value = getattr(
+            self,
+            pk._meta.name,
+        )
+
+        other_pk_value = getattr(
+            other,
+            pk._meta.name,
+        )
+
+        # Make sure the primary key values are of the correct type.
+        # We need this for `Serial` columns, which have a `QueryString`
+        # value until saved in the database. We don't want to use `==` on
+        # two QueryString values, because QueryString has a custom `__eq__`
+        # method which doesn't return a boolean.
+        if isinstance(
+            pk_value,
+            pk.value_type,
+        ) and isinstance(
+            other_pk_value,
+            pk.value_type,
+        ):
+            return pk_value == other_pk_value
+        else:
+            # As a fallback, even if it hasn't been saved in the database,
+            # an object should still be equal to itself.
+            return other is self

     ###########################################################################
     # Classmethods

+    @classmethod
+    def all_related(
+        cls, exclude: Optional[list[Union[str, ForeignKey]]] = None
+    ) -> list[ForeignKey]:
+        """
+        Used in conjunction with ``objects`` queries. Just as we can use
+        ``all_related`` on a ``ForeignKey``, you can also use it for the
+        table at the root of the query, which will return each related row
+        as a nested object. For example:
+
+        .. code-block:: python
+
+            concert = await Concert.objects(
+                Concert.all_related()
+            )
+
+            >>> concert.band_1
+            <Band: 1>
+
+            >>> concert.band_2
+            <Band: 2>
+
+            >>> concert.venue
+            <Venue: 1>
+
+        This is mostly useful when the table has a lot of foreign keys, and
+        typing them out by hand would be tedious. It's equivalent to:
+
+        .. code-block:: python
+
+            concert = await Concert.objects(
+                Concert.venue,
+                Concert.band_1,
+                Concert.band_2
+            )
+
+        :param exclude:
+            You can request all columns, except these.
+ + """ + if exclude is None: + exclude = [] + excluded_column_names = [ + i._meta.name if isinstance(i, ForeignKey) else i for i in exclude + ] + + return [ + i + for i in cls._meta.foreign_key_columns + if i._meta.name not in excluded_column_names + ] + + @classmethod + def all_columns( + cls, exclude: Optional[Sequence[Union[str, Column]]] = None + ) -> list[Column]: + """ + Used in conjunction with ``select`` queries. Just as we can use + ``all_columns`` to retrieve all of the columns from a related table, + we can also use it at the root of our query to get all of the columns + for the root table. For example: + + .. code-block:: python + + await Band.select( + Band.all_columns(), + Band.manager.all_columns() + ) + + This is mostly useful when the table has a lot of columns, and typing + them out by hand would be tedious. + + :param exclude: + You can request all columns, except these. + + """ + if exclude is None: + exclude = [] + excluded_column_names = [ + i._meta.name if isinstance(i, Column) else i for i in exclude + ] + + return [ + i + for i in cls._meta.columns + if i._meta.name not in excluded_column_names + ] + @classmethod def ref(cls, column_name: str) -> Column: """ @@ -490,7 +1014,9 @@ def ref(cls, column_name: str) -> Column: ever need to do this, but other libraries built on top of Piccolo may need this functionality. - Example: Band.ref('manager.name') + .. code-block:: python + + Band.ref('manager.name') """ local_column_name, reference_column_name = column_name.split(".") @@ -512,49 +1038,63 @@ def ref(cls, column_name: str) -> Column: return _reference_column @classmethod - def insert(cls, *rows: "Table") -> Insert: + def insert( + cls: type[TableInstance], *rows: TableInstance + ) -> Insert[TableInstance]: """ - await Band.insert( - Band(name="Pythonistas", popularity=500, manager=1) - ).run() + Insert rows into the database. + + .. code-block:: python + + await Band.insert( + Band(name="Pythonistas", popularity=500, manager=1) + ) + """ - query = Insert(table=cls) + query = Insert(table=cls).returning(cls._meta.primary_key) if rows: query.add(*rows) return query @classmethod - def raw(cls, sql: str, *args: t.Any) -> Raw: + def raw(cls, sql: str, *args: Any) -> Raw: """ Execute raw SQL queries on the underlying engine - use with caution! - await Band.raw('select * from band').run() + .. code-block:: python + + await Band.raw('select * from band') Or passing in parameters: - await Band.raw("select * from band where name = {}", 'Pythonistas') + .. code-block:: python + + await Band.raw("SELECT * FROM band WHERE name = {}", 'Pythonistas') + """ return Raw(table=cls, querystring=QueryString(sql, *args)) @classmethod def _process_column_args( - cls, *columns: t.Union[Selectable, str] - ) -> t.Sequence[Selectable]: + cls, *columns: Union[Selectable, str] + ) -> Sequence[Selectable]: """ Users can specify some column arguments as either Column instances, or as strings representing the column name, for convenience. Convert any string arguments to column instances. 
""" return [ - cls._meta.get_column_by_name(column) - if (isinstance(column, str)) - else column + ( + cls._meta.get_column_by_name(column) + if (isinstance(column, str)) + else column + ) for column in columns ] @classmethod def select( - cls, *columns: t.Union[Selectable, str], exclude_secrets=False + cls, *columns: Union[Selectable, str], exclude_secrets=False ) -> Select: """ Get data in the form of a list of dictionaries, with each dictionary @@ -562,13 +1102,19 @@ def select( These are all equivalent: - await Band.select().columns(Band.name).run() - await Band.select(Band.name).run() - await Band.select('name').run() + .. code-block:: python + + await Band.select().columns(Band.name) + await Band.select(Band.name) + await Band.select('name') + + :param exclude_secrets: + If ``True``, any columns with ``secret=True`` are omitted from the + response. For example, we use this for the password column of + :class:`BaseUser `. Even though + the passwords are hashed, you still don't want them being passed + over the network if avoidable. - :param exclude_secrets: If True, any password fields are omitted from - the response. Even though passwords are hashed, you still don't want - them being passed over the network if avoidable. """ _columns = cls._process_column_args(*columns) return Select( @@ -580,26 +1126,37 @@ def delete(cls, force=False) -> Delete: """ Delete rows from the table. - await Band.delete().where(Band.name == 'Pythonistas').run() + .. code-block:: python + + await Band.delete().where(Band.name == 'Pythonistas') + + :param force: + Unless set to ``True``, deletions aren't allowed without a + ``where`` clause, to prevent accidental mass deletions. - Unless 'force' is set to True, deletions aren't allowed without a - 'where' clause, to prevent accidental mass deletions. """ return Delete(table=cls, force=force) @classmethod def create_table( - cls, if_not_exists=False, only_default_columns=False + cls, + if_not_exists=False, + only_default_columns=False, + auto_create_schema: bool = True, ) -> Create: """ Create table, along with all columns. - await Band.create_table().run() + .. code-block:: python + + await Band.create_table() + """ return Create( table=cls, if_not_exists=if_not_exists, only_default_columns=only_default_columns, + auto_create_schema=auto_create_schema, ) @classmethod @@ -607,44 +1164,112 @@ def alter(cls) -> Alter: """ Used to modify existing tables and columns. - await Band.alter().rename_column(Band.popularity, 'rating').run() + .. code-block:: python + + await Band.alter().rename_column(Band.popularity, 'rating') + """ return Alter(table=cls) @classmethod - def objects(cls) -> Objects: + def objects( + cls: type[TableInstance], + *prefetch: Union[ForeignKey, list[ForeignKey]], + ) -> Objects[TableInstance]: """ Returns a list of table instances (each representing a row), which you can modify and then call 'save' on, or can delete by calling 'remove'. - pythonistas = await Band.objects().where( - Band.name == 'Pythonistas' - ).first().run() + .. code-block:: python + + pythonistas = await Band.objects().where( + Band.name == 'Pythonistas' + ).first() - pythonistas.name = 'Pythonistas Reborn' + pythonistas.name = 'Pythonistas Reborn' - await pythonistas.save().run() + await pythonistas.save() + + # Or to remove it from the database: + await pythonistas.remove() + + :param prefetch: + Rather than returning the primary key value of this related table, + a nested object will be returned for the row on the related table. + + .. 
     @classmethod
-    def objects(cls) -> Objects:
+    def objects(
+        cls: type[TableInstance],
+        *prefetch: Union[ForeignKey, list[ForeignKey]],
+    ) -> Objects[TableInstance]:
         """
         Returns a list of table instances (each representing a row), which
         you can modify and then call 'save' on, or can delete by calling
         'remove'.

-        pythonistas = await Band.objects().where(
-            Band.name == 'Pythonistas'
-        ).first().run()
+        .. code-block:: python
+
+            pythonistas = await Band.objects().where(
+                Band.name == 'Pythonistas'
+            ).first()

-        pythonistas.name = 'Pythonistas Reborn'
+            pythonistas.name = 'Pythonistas Reborn'

-        await pythonistas.save().run()
+            await pythonistas.save()
+
+            # Or to remove it from the database:
+            await pythonistas.remove()
+
+        :param prefetch:
+            Rather than returning the primary key value of this related
+            table, a nested object will be returned for the row on the
+            related table.
+
+            .. code-block:: python
+
+                # Without nested
+                band = await Band.objects().first()
+                >>> band.manager
+                1
+
+                # With nested
+                band = await Band.objects(Band.manager).first()
+                >>> band.manager
+                <Manager: 1>

-        # Or to remove it from the database:
-        await pythonistas.remove()
         """
-        return Objects(table=cls)
+        return Objects[TableInstance](table=cls, prefetch=prefetch)

     @classmethod
-    def count(cls) -> Count:
+    def count(
+        cls,
+        column: Optional[Column] = None,
+        distinct: Optional[Sequence[Column]] = None,
+    ) -> Count:
         """
-        Count the number of matching rows.
+        Count the number of matching rows::
+
+            await Band.count().where(Band.popularity > 1000)
+
+        :param column:
+            If specified, just count rows where this column isn't null.
+
+        :param distinct:
+            Counts the number of distinct values for these columns. For
+            example, if we have a concerts table::
+
+                class Concert(Table):
+                    band = Varchar()
+                    start_date = Date()
+
+            With this data:
+
+            .. table::
+                :widths: auto
+
+                =========== ==========
+                band        start_date
+                =========== ==========
+                Pythonistas 2023-01-01
+                Pythonistas 2023-02-03
+                Rustaceans  2023-01-01
+                =========== ==========
+
+            Without the ``distinct`` argument, we get the count of all
+            rows::
+
+                >>> await Concert.count()
+                3
+
+            To get the number of unique concert dates::
+
+                >>> await Concert.count(distinct=[Concert.start_date])
+                2

-        await Band.count().where(Band.popularity > 1000).run()
         """
-        return Count(table=cls)
+        return Count(table=cls, column=column, distinct=distinct)

     @classmethod
     def exists(cls) -> Exists:
         """
         Use it to check if a row exists, not if the table exists.

-        await Band.exists().where(Band.name == 'Pythonistas').run()
+        .. code-block:: python
+
+            await Band.exists().where(Band.name == 'Pythonistas')
+
         """
         return Exists(table=cls)

@@ -653,13 +1278,20 @@ def table_exists(cls) -> TableExists:
         """
         Check if the table exists in the database.

-        await Band.table_exists().run()
+        .. code-block:: python
+
+            await Band.table_exists()
+
         """
         return TableExists(table=cls)

     @classmethod
     def update(
-        cls, values: t.Dict[t.Union[Column, str], t.Any] = {}, **kwargs
+        cls,
+        values: Optional[dict[Union[Column, str], Any]] = None,
+        force: bool = False,
+        use_auto_update: bool = True,
+        **kwargs,
     ) -> Update:
         """
         Update rows.
@@ -671,38 +1303,56 @@

         await Band.update(
             {Band.name: "Spamalot"}
         ).where(
-            Band.name=="Pythonistas"
-        ).run()
+            Band.name == "Pythonistas"
+        )

         await Band.update(
             {"name": "Spamalot"}
         ).where(
-            Band.name=="Pythonistas"
-        ).run()
+            Band.name == "Pythonistas"
+        )

         await Band.update(
             name="Spamalot"
         ).where(
-            Band.name=="Pythonistas"
-        ).run()
+            Band.name == "Pythonistas"
+        )
+
+        :param force:
+            Unless set to ``True``, updates aren't allowed without a
+            ``where`` clause, to prevent accidental mass overriding of data.
+
+        :param use_auto_update:
+            Whether to use the ``auto_update`` values on any columns. See
+            the ``auto_update`` argument on
+            :class:`Column <piccolo.columns.base.Column>` for more
+            information.

         """
+        if values is None:
+            values = {}
         values = dict(values, **kwargs)
-        return Update(table=cls).values(values)
+
+        if use_auto_update and cls._meta.auto_update_columns:
+            values.update(cls._meta.get_auto_update_values())  # type: ignore
+
+        return Update(table=cls, force=force).values(values)
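A brief sketch of the ``force`` safeguard described above (assumed behaviour, based on the parameter documentation rather than on test output):

.. code-block:: python

    # Raises an error, because there's no `where` clause:
    await Band.update({Band.popularity: 0})

    # Allowed, because `force=True` was passed:
    await Band.update({Band.popularity: 0}, force=True)
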

     @classmethod
     def indexes(cls) -> Indexes:
         """
         Returns a list of the indexes for this table.

-        await Band.indexes().run()
+        .. code-block:: python
+
+            await Band.indexes()
+
         """
         return Indexes(table=cls)

     @classmethod
     def create_index(
         cls,
-        columns: t.List[t.Union[Column, str]],
+        columns: Union[list[Column], list[str]],
         method: IndexMethod = IndexMethod.btree,
         if_not_exists: bool = False,
     ) -> CreateIndex:
@@ -710,7 +1360,10 @@
         Create a table index. If multiple columns are specified, this refers
         to a multicolumn index, rather than multiple single column indexes.

-        await Band.create_index([Band.name]).run()
+        .. code-block:: python
+
+            await Band.create_index([Band.name])
+
         """
         return CreateIndex(
             table=cls,
@@ -721,20 +1374,25 @@

     @classmethod
     def drop_index(
-        cls, columns: t.List[t.Union[Column, str]], if_exists: bool = True
+        cls,
+        columns: Union[list[Column], list[str]],
+        if_exists: bool = True,
     ) -> DropIndex:
         """
         Drop a table index. If multiple columns are specified, this refers
         to a multicolumn index, rather than multiple single column indexes.

-        await Band.drop_index([Band.name]).run()
+        .. code-block:: python
+
+            await Band.drop_index([Band.name])
+
         """
         return DropIndex(table=cls, columns=columns, if_exists=if_exists)

     ###########################################################################

     @classmethod
-    def _get_index_name(cls, column_names: t.List[str]) -> str:
+    def _get_index_name(cls, column_names: list[str]) -> str:
         """
         Generates an index name from the table name and column names.
         """
@@ -743,32 +1401,92 @@

     ###########################################################################

     @classmethod
-    def _table_str(cls, abbreviated=False):
+    def _table_str(
+        cls,
+        abbreviated: bool = False,
+        excluded_params: Optional[list[str]] = None,
+    ):
         """
         Returns a basic string representation of the table and its columns.
         Used by the playground.

         :param abbreviated:
-            If True, a very high level representation is printed out.
+            If True, a very high level representation is printed out (it
+            just shows any non-default values).
+        :param excluded_params:
+            Lets us find a middle ground between outputting every kwarg, and
+            the abbreviated version with very few kwargs. For example
+            `['index_method']`, if we want to show all kwargs but
+            index_method.

         """
+        from piccolo.apps.migrations.auto.serialisation import (
+            SerialisedEnumTypeDefinition,
+            serialise_params,
+        )
+
+        if excluded_params is None:
+            excluded_params = []
+
         spacer = "\n    "
         columns = []
+        extra_definitions = []
         for col in cls._meta.columns:
-            params: t.List[str] = []
+            base_column_defaults = {
+                key: value.default
+                for key, value in inspect.signature(Column).parameters.items()
+            }
+            column_defaults = {
+                key: value.default
+                for key, value in inspect.signature(
+                    col.__class__
+                ).parameters.items()
+            }
+            defaults = {**base_column_defaults, **column_defaults}
+
+            params = {}
             for key, value in col._meta.params.items():
-                _value: str = ""
-                if inspect.isclass(value):
-                    _value = value.__name__
-                    params.append(f"{key}={_value}")
-                else:
-                    _value = repr(value)
-                    if not abbreviated:
-                        params.append(f"{key}={_value}")
-            params_string = ", ".join(params)
+                if key in excluded_params:
+                    continue
+
+                if abbreviated:
+                    # If the value is just the default one, don't include it.
+                    if defaults.get(key, ...) == value:
+                        continue
+
+                # If db_column is the same as the column name then don't
+                # include it - it does nothing.
+                if key == "db_column_name" and value == col._meta.name:
+                    continue
+
+                params[key] = value
+
+            serialised_params = serialise_params(params, inline_enums=False)
+            params_string = ", ".join(
+                f"{key}={repr(value)}"
+                for key, value in serialised_params.params.items()
+            )
             columns.append(
                 f"{col._meta.name} = {col.__class__.__name__}({params_string})"
             )
+            extra_definitions.extend(
+                [
+                    i
+                    for i in serialised_params.extra_definitions
+                    if isinstance(i, SerialisedEnumTypeDefinition)
+                ]
+            )
+
+        for m2m_relationship in cls._meta.m2m_relationships:
+            joining_table_name = (
+                m2m_relationship._meta.resolved_joining_table.__name__
+            )
+            columns.append(
+                f"{m2m_relationship._meta.name} = M2M({joining_table_name})"
+            )
+
+        extra_definitions_string = spacer.join(
+            [repr(i) for i in extra_definitions]
+        )

         columns_string = spacer.join(columns)
         tablename = repr(cls._meta.tablename)
@@ -780,17 +1498,19 @@
             else f"{parent_class_name}, tablename={tablename}"
         )

-        return (
-            f"class {cls.__name__}({class_args}):\n" f"    {columns_string}\n"
-        )
+        output = f"class {cls.__name__}({class_args}):\n"
+        if extra_definitions_string:
+            output += f"    {extra_definitions_string}\n"
+        output += f"    {columns_string}\n"
+        return output


 def create_table_class(
     class_name: str,
-    bases: t.Tuple[t.Type] = (Table,),
-    class_kwargs: t.Dict[str, t.Any] = {},
-    class_members: t.Dict[str, t.Any] = {},
-) -> t.Type[Table]:
+    bases: tuple[type] = (Table,),
+    class_kwargs: dict[str, Any] = {},
+    class_members: dict[str, Any] = {},
+) -> type[Table]:
     """
     Used to dynamically create ``Table`` subclasses at runtime. Most users
     will not require this. It's mostly used internally for Piccolo's
@@ -806,9 +1526,211 @@
         For example, `{'my_column': Varchar()}`.

     """
-    return types.new_class(
-        name=class_name,
-        bases=bases,
-        kwds=class_kwargs,
-        exec_body=lambda namespace: namespace.update(class_members),
+    return cast(
+        type[Table],
+        types.new_class(
+            name=class_name,
+            bases=bases,
+            kwds=class_kwargs,
+            exec_body=lambda namespace: namespace.update(class_members),
+        ),
     )
+
+
+###############################################################################
+# Quickly create or drop database tables from Piccolo `Table` classes.
+
+
+async def create_db_tables(
+    *tables: type[Table], if_not_exists: bool = False
+) -> None:
+    """
+    Creates the database table for each ``Table`` class passed in. The
+    tables are created in the correct order, based on their foreign keys.
+
+    :param tables:
+        The tables to create in the database.
+    :param if_not_exists:
+        No errors will be raised if any of the tables already exist in the
+        database.
+
+    """
+    if tables:
+        engine = tables[0]._meta.db
+    else:
+        return
+
+    sorted_table_classes = sort_table_classes(list(tables))
+
+    atomic = engine.atomic()
+    atomic.add(
+        *[
+            table.create_table(if_not_exists=if_not_exists)
+            for table in sorted_table_classes
+        ]
+    )
+    await atomic.run()
+
+
+def create_db_tables_sync(
+    *tables: type[Table], if_not_exists: bool = False
+) -> None:
+    """
+    A sync wrapper around :func:`create_db_tables`.
+    """
+    run_sync(create_db_tables(*tables, if_not_exists=if_not_exists))
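A quick sketch of the intended usage, assuming ``Band`` has a foreign key to ``Manager`` (so the topological sort creates ``Manager`` first, regardless of argument order):

.. code-block:: python

    from piccolo.table import create_db_tables, drop_db_tables

    await create_db_tables(Band, Manager, if_not_exists=True)

    # ...and when tearing down:
    await drop_db_tables(Band, Manager)
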
+
+
+def create_tables(*tables: type[Table], if_not_exists: bool = False) -> None:
+    """
+    This original implementation has been replaced, because it was
+    synchronous, and felt at odds with the rest of the Piccolo codebase
+    which is async first.
+
+    Instead, use ``create_db_tables`` for asynchronous code, or
+    ``create_db_tables_sync`` for synchronous code.
+    """
+    colored_warning(
+        "`create_tables` is deprecated and will be removed in v1 of Piccolo. "
+        "Use `await create_db_tables(...)` or `create_db_tables_sync(...)` "
+        "instead.",
+        category=DeprecationWarning,
+    )
+
+    return create_db_tables_sync(*tables, if_not_exists=if_not_exists)
+
+
+async def drop_db_tables(*tables: type[Table]) -> None:
+    """
+    Drops the database table for each ``Table`` class passed in. The tables
+    are dropped in the correct order, based on their foreign keys.
+
+    :param tables:
+        The tables to delete from the database.
+
+    """
+    if tables:
+        engine = tables[0]._meta.db
+    else:
+        return
+
+    if engine.engine_type == "sqlite":
+        # SQLite doesn't support CASCADE, so we have to drop them in the
+        # correct order.
+        sorted_table_classes = reversed(sort_table_classes(list(tables)))
+        ddl_statements = [
+            Alter(table=table).drop_table(if_exists=True)
+            for table in sorted_table_classes
+        ]
+    else:
+        ddl_statements = [
+            table.alter().drop_table(cascade=True, if_exists=True)
+            for table in tables
+        ]
+
+    atomic = engine.atomic()
+    atomic.add(*ddl_statements)
+    await atomic.run()
+
+
+def drop_db_tables_sync(*tables: type[Table]) -> None:
+    """
+    A sync wrapper around :func:`drop_db_tables`.
+    """
+    run_sync(drop_db_tables(*tables))
+
+
+def drop_tables(*tables: type[Table]) -> None:
+    """
+    This original implementation has been replaced, because it was
+    synchronous, and felt at odds with the rest of the Piccolo codebase
+    which is async first.
+
+    Instead, use ``drop_db_tables`` for asynchronous code, or
+    ``drop_db_tables_sync`` for synchronous code.
+    """
+    colored_warning(
+        "`drop_tables` is deprecated and will be removed in v1 of Piccolo. "
+        "Use `await drop_db_tables(...)` or `drop_db_tables_sync(...)` "
+        "instead.",
+        category=DeprecationWarning,
+    )
+
+    return drop_db_tables_sync(*tables)
+
+
+###############################################################################
+
+
+def sort_table_classes(
+    table_classes: list[type[Table]],
+) -> list[type[Table]]:
+    """
+    Sort the table classes based on their foreign keys, so they can be
+    created in the correct order.
+    """
+    table_class_dict = {
+        table_class._meta.tablename: table_class
+        for table_class in table_classes
+    }
+
+    graph = _get_graph(table_classes)
+
+    sorter = TopologicalSorter(graph)
+    ordered_tablenames = tuple(sorter.static_order())
+
+    output: list[type[Table]] = []
+    for tablename in ordered_tablenames:
+        table_class = table_class_dict.get(tablename)
+        if table_class is not None:
+            output.append(table_class)
+
+    return output
+
+
+def _get_graph(
+    table_classes: list[type[Table]],
+    iterations: int = 0,
+    max_iterations: int = 5,
+) -> dict[str, set[str]]:
+    """
+    Analyses the tables based on their foreign keys, and returns a data
+    structure like:
+
+    .. code-block:: python
+
+        {'band': {'manager'}, 'concert': {'band', 'venue'}, 'manager': set()}
+
+    The keys are tablenames, and the values are the tablenames directly
+    connected to them via a foreign key.
+
+    """
+    output: dict[str, set[str]] = {}
+
+    if iterations >= max_iterations:
+        return output
+
+    for table_class in table_classes:
+        dependents: set[str] = set()
+        for fk in table_class._meta.foreign_key_columns:
+            referenced_table = fk._foreign_key_meta.resolved_references
+
+            if referenced_table._meta.tablename == table_class._meta.tablename:
+                # Most likely a recursive link (using ForeignKey('self')).
+                continue
+
+            dependents.add(referenced_table._meta.tablename)
+
+            # We also recursively check the related tables to get a fuller
+            # picture of the schema and relationships.
+            if referenced_table._meta.tablename not in output:
+                output.update(
+                    _get_graph(
+                        [referenced_table],
+                        iterations=iterations + 1,
+                    )
+                )
+
+        output[table_class._meta.tablename] = dependents
+
+    return output
diff --git a/piccolo/table_reflection.py b/piccolo/table_reflection.py
new file mode 100644
index 000000000..52618761b
--- /dev/null
+++ b/piccolo/table_reflection.py
@@ -0,0 +1,239 @@
+"""
+This is an advanced Piccolo feature which allows runtime reflection of
+database tables.
+"""
+
+import asyncio
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+from piccolo.apps.schema.commands.generate import get_output_schema
+from piccolo.engine import engine_finder
+from piccolo.engine.base import Engine
+from piccolo.table import Table
+
+
+class Immutable(object):
+    def _immutable(self, *args, **kwargs) -> TypeError:
+        raise TypeError("%s object is immutable" % self.__class__.__name__)
+
+    __delitem__ = __setitem__ = __setattr__ = _immutable  # type: ignore
+
+
+class ImmutableDict(Immutable, dict):  # type: ignore
+    """A dictionary that is not publicly mutable."""
+
+    clear = pop = popitem = setdefault = update = Immutable._immutable  # type: ignore # noqa: E501
+
+    def __new__(cls, *args):
+        return dict.__new__(cls)
+
+    def copy(self):
+        raise NotImplementedError(
+            "an immutabledict shouldn't need to be copied. use dict(d) "
+            "if you need a mutable dictionary."
+        )
+
+    def __reduce__(self):
+        return ImmutableDict, (dict(self),)
+
+    def _insert_item(self, key, value) -> None:
+        """
+        Insert an item into the dictionary directly.
+        """
+        dict.__setitem__(self, key, value)
+
+    def _delete_item(self, key) -> None:
+        """
+        Delete an item from the dictionary directly.
+        """
+        dict.__delitem__(self, key)
+
+    def __repr__(self):
+        return f"ImmutableDict({dict.__repr__(self)})"
+
+
+class Singleton(type):
+    """
+    A metaclass that creates a Singleton base class when called.
+    """
+
+    _instances: dict = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(
+                *args, **kwargs
+            )
+        return cls._instances[cls]
+
+
+@dataclass
+class TableNameDetail:
+    name: str = ""
+    schema: str = ""
+
+
+class TableStorage(metaclass=Singleton):
+    """
+    A singleton object to store and access reflected tables. Currently it
+    just works with Postgres.
+    """
+
+    def __init__(self, engine: Optional[Engine] = None):
+        """
+        :param engine:
+            Which engine to use to make the database queries. If not
+            specified, we try importing an engine from ``piccolo_conf.py``.
+
+        """
+        self.engine = engine or engine_finder()
+        self.tables = ImmutableDict()
+        self._schema_tables: dict[str, list[str]] = {}
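A hedged sketch of how ``TableStorage`` is meant to be driven (the table and schema names are illustrative):

.. code-block:: python

    storage = TableStorage()

    # Reflect everything in the `public` schema, except one table:
    await storage.reflect(schema_name="public", exclude=["migrations"])

    # Fetch a reflected Table class - this triggers reflection if the
    # table isn't already present:
    Band = await storage.get_table("band")
    rows = await Band.select()
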
+
+    async def reflect(
+        self,
+        schema_name: str = "public",
+        include: Union[list[str], str, None] = None,
+        exclude: Union[list[str], str, None] = None,
+        keep_existing: bool = False,
+    ) -> None:
+        """
+        Imports tables from the database into ``Table`` objects without
+        hard-coding them.
+
+        If a table has a reference to another table, the referenced table
+        will be imported too. Reflection can have a performance impact based
+        on the number of tables.
+
+        If you want to reflect your whole database, make sure to only do it
+        once, or use the provided parameters instead of reflecting the whole
+        database every time.
+
+        :param schema_name:
+            Name of the schema you want to reflect.
+        :param include:
+            It will only reflect the specified tables. Can be a list of
+            tables or a single table.
+        :param exclude:
+            It won't reflect the specified tables. Can be a list of tables
+            or a single table.
+        :param keep_existing:
+            If True, it will exclude the available tables and reflect the
+            currently unavailable ones. Default is False.
+        :returns:
+            None
+
+        """
+        include_list = self._to_list(include)
+        exclude_list = self._to_list(exclude)
+
+        if keep_existing:
+            exclude_list += self._schema_tables.get(schema_name, [])
+
+        output_schema = await get_output_schema(
+            schema_name=schema_name,
+            include=include_list,
+            exclude=exclude_list,
+            engine=self.engine,
+        )
+        add_tables = [
+            self._add_table(schema_name=schema_name, table=table)
+            for table in output_schema.tables
+        ]
+        await asyncio.gather(*add_tables)
+
+    def clear(self) -> None:
+        """
+        Removes all the tables within ``TableStorage``.
+
+        :returns:
+            None
+
+        """
+        dict.clear(self.tables)
+        self._schema_tables.clear()
+
+    async def get_table(self, tablename: str) -> Optional[type[Table]]:
+        """
+        Returns the ``Table`` class if it exists. If the table is not present
+        in ``TableStorage``, it will try to reflect it.
+
+        :param tablename:
+            The name of the table, schema name included. If the schema is
+            public, it's not necessary. For example: "public.manager" or
+            "manager", "test_schema.test_table".
+        :returns:
+            Table | None
+
+        """
+        table_class = self.tables.get(tablename)
+        if table_class is None:
+            tableNameDetail = self._get_schema_and_table_name(tablename)
+            await self.reflect(
+                schema_name=tableNameDetail.schema,
+                include=[tableNameDetail.name],
+            )
+            table_class = self.tables.get(tablename)
+        return table_class
+
+    async def _add_table(self, schema_name: str, table: type[Table]) -> None:
+        if issubclass(table, Table):
+            table_name = self._get_table_name(
+                table._meta.tablename, schema_name
+            )
+            self.tables._insert_item(table_name, table)
+            self._add_to_schema_tables(
+                schema_name=schema_name, table_name=table._meta.tablename
+            )
+
+    def _add_to_schema_tables(self, schema_name: str, table_name: str) -> None:
+        """
+        We keep a record of schemas and their tables for easy use. This
+        method adds a table to its schema.
+
+        """
+        schema_tables = self._schema_tables.get(schema_name)
+        if schema_tables is None:
+            self._schema_tables[schema_name] = [table_name]
+        else:
+            self._schema_tables[schema_name].append(table_name)
+
+    @staticmethod
+    def _get_table_name(name: str, schema: str):
+        return name if schema == "public" else f"{schema}.{name}"
+
+    def __repr__(self):
+        return f"{[tablename for tablename, _ in self.tables.items()]}"
+
+    @staticmethod
+    def _get_schema_and_table_name(tablename: str) -> TableNameDetail:
+        """
+        Extract the schema name and table name from the full name of the
+        table.
+
+        :param tablename:
+            The full name of the table.
+        :returns:
+            Returns the name of the schema and the table.
+ + """ + tablename_list = tablename.split(".") + if len(tablename_list) == 2: + return TableNameDetail( + name=tablename_list[1], schema=tablename_list[0] + ) + + elif len(tablename_list) == 1: + return TableNameDetail(name=tablename_list[0], schema="public") + else: + raise ValueError("Couldn't find schema name.") + + @staticmethod + def _to_list(value: Any) -> list: + if isinstance(value, list): + return value + elif isinstance(value, (tuple, set)): + return list(value) + elif isinstance(value, str): + return [value] + return [] diff --git a/piccolo/testing/__init__.py b/piccolo/testing/__init__.py new file mode 100644 index 000000000..3cce94e1f --- /dev/null +++ b/piccolo/testing/__init__.py @@ -0,0 +1,3 @@ +from piccolo.testing.model_builder import ModelBuilder + +__all__ = ["ModelBuilder"] diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py new file mode 100644 index 000000000..7e25d7e10 --- /dev/null +++ b/piccolo/testing/model_builder.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import datetime +import decimal +import json +from collections.abc import Callable +from decimal import Decimal +from typing import Any, Optional, Union, cast +from uuid import UUID + +from piccolo.columns import JSON, JSONB, Array, Column, ForeignKey +from piccolo.custom_types import TableInstance +from piccolo.testing.random_builder import RandomBuilder +from piccolo.utils.sync import run_sync + + +class ModelBuilder: + __DEFAULT_MAPPER: dict[type, Callable] = { + bool: RandomBuilder.next_bool, + bytes: RandomBuilder.next_bytes, + datetime.date: RandomBuilder.next_date, + datetime.datetime: RandomBuilder.next_datetime, + float: RandomBuilder.next_float, + decimal.Decimal: RandomBuilder.next_decimal, + int: RandomBuilder.next_int, + str: RandomBuilder.next_str, + datetime.time: RandomBuilder.next_time, + datetime.timedelta: RandomBuilder.next_timedelta, + UUID: RandomBuilder.next_uuid, + } + + @classmethod + async def build( + cls, + table_class: type[TableInstance], + defaults: Optional[dict[Union[Column, str], Any]] = None, + persist: bool = True, + minimal: bool = False, + ) -> TableInstance: + """ + Build a ``Table`` instance with random data and save async. + If the ``Table`` has any foreign keys, then the related rows are also + created automatically. + + :param table_class: + Table class to randomize. + :param defaults: + Any values specified here will be used instead of random values. + :param persist: + Whether to save the new instance in the database. + :param minimal: + If ``True`` then any columns with ``null=True`` are assigned + a value of ``None``. 
+ + Examples:: + + # Create a new instance with all random values: + manager = await ModelBuilder.build(Manager) + + # Create a new instance, with certain defaults: + manager = await ModelBuilder.build( + Manager, + {Manager.name: 'Guido'} + ) + + # Create a new instance, but don't save it in the database: + manager = await ModelBuilder.build(Manager, persist=False) + + # Create a new instance, with all null values set to None: + manager = await ModelBuilder.build(Manager, minimal=True) + + # We can pass other table instances in as default values: + band = await ModelBuilder.build(Band, {Band.manager: manager}) + + """ + return await cls._build( + table_class=table_class, + defaults=defaults, + persist=persist, + minimal=minimal, + ) + + @classmethod + def build_sync( + cls, + table_class: type[TableInstance], + defaults: Optional[dict[Union[Column, str], Any]] = None, + persist: bool = True, + minimal: bool = False, + ) -> TableInstance: + """ + A sync wrapper around :meth:`build`. + """ + return run_sync( + cls.build( + table_class=table_class, + defaults=defaults, + persist=persist, + minimal=minimal, + ) + ) + + @classmethod + async def _build( + cls, + table_class: type[TableInstance], + defaults: Optional[dict[Union[Column, str], Any]] = None, + minimal: bool = False, + persist: bool = True, + ) -> TableInstance: + model = table_class(_ignore_missing=True) + defaults = {} if not defaults else defaults + + for column, value in defaults.items(): + if isinstance(column, str): + column = model._meta.get_column_by_name(column) + + setattr(model, column._meta.name, value) + + for column in model._meta.columns: + if column._meta.null and minimal: + continue + + if column._meta.name in defaults: + continue # Column value exists + + if isinstance(column, ForeignKey) and persist: + # Check for recursion + if column._foreign_key_meta.references is table_class: + if column._meta.null is True: + # We can avoid this problem entirely by setting it to + # None. + random_value = None + else: + # There's no way to avoid recursion in the situation. + raise ValueError("Recursive foreign key detected") + else: + reference_model = await cls._build( + column._foreign_key_meta.resolved_references, + persist=True, + ) + random_value = getattr( + reference_model, + reference_model._meta.primary_key._meta.name, + ) + else: + random_value = cls._randomize_attribute(column) + + setattr(model, column._meta.name, random_value) + + if persist: + await model.save().run() + + return model + + @classmethod + def _randomize_attribute(cls, column: Column) -> Any: + """ + Generate a random value for a column and apply formatting. + + :param column: + Column class to randomize. 
+ + """ + random_value: Any + if column.value_type == Decimal: + precision, scale = column._meta.params["digits"] or (4, 2) + random_value = RandomBuilder.next_decimal( + precision=precision, scale=scale + ) + elif column.value_type == datetime.datetime: + tz_aware = getattr(column, "tz_aware", False) + random_value = RandomBuilder.next_datetime(tz_aware=tz_aware) + elif column.value_type == list: + length = RandomBuilder.next_int(maximum=10) + if column._meta.choices: + random_value = [ + RandomBuilder.next_enum(column._meta.choices) + for _ in range(length) + ] + else: + base_type = cast(Array, column).base_column.value_type + random_value = [ + cls.__DEFAULT_MAPPER[base_type]() for _ in range(length) + ] + elif column._meta.choices: + random_value = RandomBuilder.next_enum(column._meta.choices) + else: + random_value = cls.__DEFAULT_MAPPER[column.value_type]() + + if "length" in column._meta.params and isinstance(random_value, str): + return random_value[: column._meta.params["length"]] + elif isinstance(column, (JSON, JSONB)): + return json.dumps({"value": random_value}) + + return random_value diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py new file mode 100644 index 000000000..537cb8113 --- /dev/null +++ b/piccolo/testing/random_builder.py @@ -0,0 +1,86 @@ +import datetime +import decimal +import enum +import random +import string +import uuid +from typing import Any + + +class RandomBuilder: + @classmethod + def next_bool(cls) -> bool: + return random.choice([True, False]) + + @classmethod + def next_bytes(cls, length=8) -> bytes: + return random.getrandbits(length * 8).to_bytes(length, "little") + + @classmethod + def next_date(cls) -> datetime.date: + return datetime.date( + year=random.randint(2000, 2050), + month=random.randint(1, 12), + day=random.randint(1, 28), + ) + + @classmethod + def next_datetime(cls, tz_aware: bool = False) -> datetime.datetime: + return datetime.datetime( + year=random.randint(2000, 2050), + month=random.randint(1, 12), + day=random.randint(1, 28), + hour=random.randint(0, 23), + minute=random.randint(0, 59), + second=random.randint(0, 59), + tzinfo=datetime.timezone.utc if tz_aware else None, + ) + + @classmethod + def next_enum(cls, e: type[enum.Enum]) -> Any: + return random.choice([item.value for item in e]) + + @classmethod + def next_float(cls, minimum=0, maximum=2147483647, scale=5) -> float: + return round(random.uniform(minimum, maximum), scale) + + @classmethod + def next_decimal( + cls, precision: int = 4, scale: int = 2 + ) -> decimal.Decimal: + # For precision 4 and scale 2, maximum needs to be 99.99. + maximum = (10 ** (precision - scale)) - (10 ** (-1 * scale)) + float_number = cls.next_float(maximum=maximum, scale=scale) + # We convert float_number to a string first, otherwise the decimal + # value is slightly off due to floating point precision. 
+ return decimal.Decimal(str(float_number)) + + @classmethod + def next_int(cls, minimum=0, maximum=2147483647) -> int: + return random.randint(minimum, maximum) + + @classmethod + def next_str(cls, length=16) -> str: + return "".join( + random.choice(string.ascii_letters) for _ in range(length) + ) + + @classmethod + def next_time(cls) -> datetime.time: + return datetime.time( + hour=random.randint(0, 23), + minute=random.randint(0, 59), + second=random.randint(0, 59), + ) + + @classmethod + def next_timedelta(cls) -> datetime.timedelta: + return datetime.timedelta( + days=random.randint(1, 7), + hours=random.randint(1, 23), + minutes=random.randint(0, 59), + ) + + @classmethod + def next_uuid(cls) -> uuid.UUID: + return uuid.uuid4() diff --git a/piccolo/testing/test_case.py b/piccolo/testing/test_case.py new file mode 100644 index 000000000..08bd61a5c --- /dev/null +++ b/piccolo/testing/test_case.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Optional +from unittest import IsolatedAsyncioTestCase, TestCase + +from piccolo.engine import Engine, engine_finder +from piccolo.table import ( + Table, + create_db_tables, + create_db_tables_sync, + drop_db_tables, + drop_db_tables_sync, +) + + +class TableTest(TestCase): + """ + Identical to :class:`AsyncTableTest <piccolo.testing.test_case.AsyncTableTest>`, + except it only works for sync tests. Only use this if you can't make your + tests async (perhaps you're on Python 3.7 where ``IsolatedAsyncioTestCase`` + isn't available). + + For example:: + + class TestBand(TableTest): + tables = [Band] + + def test_band(self): + ... + + """ # noqa: E501 + + tables: list[type[Table]] + + def setUp(self) -> None: + create_db_tables_sync(*self.tables) + + def tearDown(self) -> None: + drop_db_tables_sync(*self.tables) + + +class AsyncTableTest(IsolatedAsyncioTestCase): + """ + Used for tests where we need to create Piccolo tables - they will + automatically be created and dropped. + + For example:: + + class TestBand(AsyncTableTest): + tables = [Band] + + async def test_band(self): + ... + + """ + + tables: list[type[Table]] + + async def asyncSetUp(self) -> None: + await create_db_tables(*self.tables) + + async def asyncTearDown(self) -> None: + await drop_db_tables(*self.tables) + + +class AsyncTransactionTest(IsolatedAsyncioTestCase): + """ + Wraps each test in a transaction, which is automatically rolled back when + the test finishes. + + .. warning:: + Python 3.11 and above only. + + If your test suite just contains ``AsyncTransactionTest`` tests, then you + can set up your database tables once before your test suite runs. Any + changes made to your tables by the tests will be rolled back automatically. + + Here's an example:: + + from piccolo.testing.test_case import AsyncTransactionTest + + + class TestBandEndpoint(AsyncTransactionTest): + + async def test_band_response(self): + \"\"\" + Make sure the endpoint returns a 200. + \"\"\" + band = Band({Band.name: "Pythonistas"}) + await band.save() + + # Using an API testing client, like httpx: + response = await client.get(f"/bands/{band.id}/") + self.assertEqual(response.status_code, 200) + + We add a ``Band`` to the database, but any subsequent tests won't see it, + as the changes are rolled back automatically. + + """ + + # We use `engine_finder` to find the current `Engine`, but you can + # explicitly set it here if you prefer: + # + # class MyTest(AsyncTransactionTest): + # db = DB + # + # ...
+ # + db: Optional[Engine] = None + + async def asyncSetUp(self) -> None: + db = self.db or engine_finder() + assert db is not None + self.transaction = db.transaction() + # This is only available in Python 3.11 and above: + await self.enterAsyncContext(cm=self.transaction) # type: ignore + + async def asyncTearDown(self): + await super().asyncTearDown() + await self.transaction.rollback() diff --git a/piccolo/utils/dictionary.py b/piccolo/utils/dictionary.py new file mode 100644 index 000000000..880950711 --- /dev/null +++ b/piccolo/utils/dictionary.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any + + +def make_nested(dictionary: dict[str, Any]) -> dict[str, Any]: + """ + Rows are returned from the database as a flat dictionary, with keys such + as ``'manager.name'`` if the column belongs to a related table. + + This function puts any values from a related table into a sub-dictionary. + + .. code-block:: python + + response = await Band.select(Band.name, Band.manager.name) + >>> print(response) + [{'name': 'Pythonistas', 'manager.name': 'Guido'}] + + >>> make_nested(response[0]) + {'name': 'Pythonistas', 'manager': {'name': 'Guido'}} + + """ + output: dict[str, Any] = {} + + items = list(dictionary.items()) + items.sort(key=lambda x: x[0]) + + for key, value in items: + path = key.split(".") + if len(path) == 1: + output[path[0]] = value + else: + # Force the root element to be an empty dictionary, if it's some + # other value (most likely an integer). This is because there are + # situations where a query can have `manager` and `manager.id`. + # For example: + # await Band.select( + # Band.all_columns(), + # Band.manager.all_columns() + # ).run() + # In this situation nesting takes precedence. + root = output.get(path[0], None) + if isinstance(root, dict): + dictionary = root + else: + dictionary = {} + output[path[0]] = dictionary + + for path_element in path[1:-1]: + root = dictionary.setdefault(path_element, {}) + if not isinstance(root, dict): + root = {} + dictionary[path_element] = root + dictionary = root + + dictionary[path[-1]] = value + + return output diff --git a/piccolo/utils/encoding.py b/piccolo/utils/encoding.py index e84b1b015..3c637079e 100644 --- a/piccolo/utils/encoding.py +++ b/piccolo/utils/encoding.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing as t +from typing import Any try: import orjson @@ -12,15 +12,63 @@ ORJSON = False -def dump_json(data: t.Any) -> str: +def dump_json(data: Any, pretty: bool = False) -> str: if ORJSON: - return orjson.dumps(data, default=str).decode("utf8") + orjson_params: dict[str, Any] = {"default": str} + if pretty: + orjson_params["option"] = ( + orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE # type: ignore + ) + return orjson.dumps(data, **orjson_params).decode( # type: ignore + "utf8" + ) else: - return json.dumps(data, default=str) + params: dict[str, Any] = {"default": str} + if pretty: + params["indent"] = 2 + return json.dumps(data, **params) # type: ignore -def load_json(data: str) -> t.Any: - if ORJSON: - return orjson.loads(data) - else: - return json.loads(data) +class JSONDict(dict): + """ + Once we have parsed a JSON string into a dictionary, we can't distinguish + it from other dictionaries. + + Sometimes we might want to - for example:: + + >>> await Album.select( + ... Album.all_columns(), + ... Album.recording_studio.all_columns() + ... ).output( + ... nested=True, + ... load_json=True + ...
) + + [{ + 'id': 1, + 'band': 1, + 'name': 'Awesome album 1', + 'recorded_at': { + 'id': 1, + 'facilities': {'restaurant': True, 'mixing_desk': True}, + 'name': 'Abbey Road' + }, + 'release_date': datetime.date(2021, 1, 1) + }] + + Facilities could be mistaken for a table. + + """ + + ... + + +def load_json(data: str) -> Any: + response = ( + orjson.loads(data) if ORJSON else json.loads(data) # type: ignore + ) + + if isinstance(response, dict): + return JSONDict(**response) + + return response diff --git a/piccolo/utils/lazy_loader.py b/piccolo/utils/lazy_loader.py index 51469b06e..7b64a896f 100644 --- a/piccolo/utils/lazy_loader.py +++ b/piccolo/utils/lazy_loader.py @@ -3,7 +3,7 @@ import importlib import types -import typing as t +from typing import Any class LazyLoader(types.ModuleType): @@ -39,19 +39,19 @@ def _load(self) -> types.ModuleType: raise ModuleNotFoundError( "PostgreSQL driver not found. " "Try running `pip install 'piccolo[postgres]'`" - ) + ) from exc elif str(exc) == "No module named 'aiosqlite'": raise ModuleNotFoundError( "SQLite driver not found. " "Try running `pip install 'piccolo[sqlite]'`" - ) + ) from exc else: - raise exc + raise exc from exc - def __getattr__(self, item) -> t.Any: + def __getattr__(self, item) -> Any: module = self._load() return getattr(module, item) - def __dir__(self) -> t.List[str]: + def __dir__(self) -> list[str]: module = self._load() return dir(module) diff --git a/piccolo/utils/list.py b/piccolo/utils/list.py new file mode 100644 index 000000000..8ec6aa066 --- /dev/null +++ b/piccolo/utils/list.py @@ -0,0 +1,60 @@ +from collections.abc import Sequence +from typing import TypeVar, Union + +ElementType = TypeVar("ElementType") + + +def flatten( + items: Sequence[Union[ElementType, list[ElementType]]] +) -> list[ElementType]: + """ + Takes a sequence of elements, and flattens it out. For example:: + + >>> flatten(['a', ['b', 'c']]) + ['a', 'b', 'c'] + + We need this for situations like this:: + + await Band.select(Band.name, Band.manager.all_columns()) + + """ + _items: list[ElementType] = [] + for item in items: + if isinstance(item, list): + _items.extend(item) + else: + _items.append(item) + + return _items + + +def batch(data: list[ElementType], chunk_size: int) -> list[list[ElementType]]: + """ + Breaks the list down into sublists of the given ``chunk_size``. The last + sublist may have fewer elements than ``chunk_size``:: + + >>> batch(['a', 'b', 'c'], 2) + [['a', 'b'], ['c']] + + :param data: + The data to break up into sublists. + :param chunk_size: + How large each sublist should be. + + """ + # TODO: Replace with itertools.batched when available: + # https://docs.python.org/3.12/library/itertools.html#itertools.batched + + if chunk_size <= 0: + raise ValueError("`chunk_size` must be greater than 0.") + + row_count = len(data) + + iterations = int(row_count / chunk_size) + if row_count % chunk_size > 0: + iterations += 1 + + return [ + data[(i * chunk_size) : ((i + 1) * chunk_size)] # noqa: E203 + for i in range(0, iterations) + ] diff --git a/piccolo/utils/naming.py b/piccolo/utils/naming.py index 62cd03b53..efc04cc2c 100644 --- a/piccolo/utils/naming.py +++ b/piccolo/utils/naming.py @@ -6,3 +6,10 @@ def _camel_to_snake(string: str): Convert CamelCase to snake_case. """ return inflection.underscore(string) + + +def _snake_to_camel(string: str): + """ + Convert snake_case to CamelCase.
+ """ + return inflection.camelize(string) diff --git a/piccolo/utils/objects.py b/piccolo/utils/objects.py new file mode 100644 index 000000000..22b48a530 --- /dev/null +++ b/piccolo/utils/objects.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from piccolo.columns.column_types import ForeignKey + +if TYPE_CHECKING: # pragma: no cover + from piccolo.table import Table + + +def make_nested_object(row: dict[str, Any], table_class: type[Table]) -> Table: + """ + Takes a nested dictionary such as this: + + .. code-block:: python + + row = { + 'id': 1, + 'name': 'Pythonistas', + 'manager': {'id': 1, 'name': 'Guido'} + } + + And returns a ``Table`` instance, with nested table instances for related + tables. + + For example: + + .. code-block:: python + + band = make_nested(row, Band) + >>> band + + >>> band.manager + + >>> band.manager.id + 1 + + """ + table_params: dict[str, Any] = {} + + for key, value in row.items(): + if isinstance(value, dict): + # This is probably a related table. + fk_column = table_class._meta.get_column_by_name(key) + + if isinstance(fk_column, ForeignKey): + related_table_class = ( + fk_column._foreign_key_meta.resolved_references + ) + table_params[key] = make_nested_object( + value, related_table_class + ) + else: + # The value doesn't belong to a foreign key, so just append it. + table_params[key] = value + + else: + table_params[key] = value + + table_instance = table_class(**table_params) + table_instance._exists_in_db = True + return table_instance diff --git a/piccolo/utils/printing.py b/piccolo/utils/printing.py index 5de6626ff..d1a0bd538 100644 --- a/piccolo/utils/printing.py +++ b/piccolo/utils/printing.py @@ -1,11 +1,57 @@ -def get_fixed_length_string(string: str, length=20) -> str: +from typing import List + + +def get_fixed_length_string(string: str, length: int = 20) -> str: """ - Add spacing to the end of the string so it's a fixed length. + Add spacing to the end of the string so it's a fixed length, or truncate + if it's too long. """ if len(string) > length: - fixed_length_string = string[: length - 3] + "..." - else: - spacing = "".join([" " for i in range(length - len(string))]) - fixed_length_string = f"{string}{spacing}" + return f"{string[: length - 3]}..." + spacing = "".join(" " for _ in range(length - len(string))) + return f"{string}{spacing}" + + +def print_heading(string: str, width: int = 64) -> None: + """ + Prints out a nicely formatted heading to the console. Useful for breaking + up the output in large CLI commands. + """ + print(f"\n{string.upper():^{width}}") + print("-" * width) + + +def print_dict_table(data: List[dict], header_separator: bool = False) -> None: + """ + Prints out a list of dictionaries in tabular form. + + Uses the first list element to extract the column names and their order + within the row. 
+ + """ + if len(data) == 0: + raise ValueError("The data must have at least one element.") + + column_names = data[0].keys() + widths = {column_name: len(column_name) for column_name in column_names} + + for item in data: + for column in column_names: + width = len(str(item[column])) + if width > widths[column]: + widths[column] = width + + format_string = " | ".join(f"{{:<{widths[w]}}}" for w in column_names) + + print(format_string.format(*[str(w) for w in column_names])) + + if header_separator: + format_string_sep = "-+-".join( + [f"{{:<{widths[w]}}}" for w in column_names] + ) + print( + format_string_sep.format(*["-" * widths[w] for w in column_names]) + ) - return fixed_length_string + for item in data: + print(format_string.format(*[str(item[w]) for w in column_names])) diff --git a/piccolo/utils/pydantic.py b/piccolo/utils/pydantic.py new file mode 100644 index 000000000..794406f17 --- /dev/null +++ b/piccolo/utils/pydantic.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import itertools +import json +from collections import defaultdict +from collections.abc import Callable +from functools import partial +from typing import Any, Optional, Union + +import pydantic + +from piccolo.columns import Column +from piccolo.columns.column_types import ( + JSON, + JSONB, + Array, + Decimal, + Email, + ForeignKey, + Numeric, + Text, + Timestamptz, + Varchar, +) +from piccolo.table import Table +from piccolo.utils.encoding import load_json + +try: + from pydantic.config import JsonDict +except ImportError: + JsonDict = dict # type: ignore + + +def pydantic_json_validator(value: Optional[str], required: bool = True): + if value is None: + if required: + raise ValueError("The JSON value wasn't provided.") + else: + return value + + try: + load_json(value) + except json.JSONDecodeError as e: + raise ValueError("Unable to parse the JSON.") from e + else: + return value + + +def is_table_column(column: Column, table: type[Table]) -> bool: + """ + Verify that the given ``Column`` belongs to the given ``Table``. + """ + if column._meta.table is table: + return True + elif ( + column._meta.call_chain + and column._meta.call_chain[0]._meta.table is table + ): + # We also allow the column if it's joined from the table. + return True + return False + + +def validate_columns(columns: tuple[Column, ...], table: type[Table]) -> bool: + """ + Verify that each column is a ``Column``` instance, and its parent is the + given ``Table``. + """ + return all( + isinstance(column, Column) + and is_table_column(column=column, table=table) + for column in columns + ) + + +def get_array_value_type(column: Array, inner: Optional[type] = None) -> type: + """ + Gets the correct type for an ``Array`` column (which might be + multidimensional). + """ + if isinstance(column.base_column, Array): + inner_type = get_array_value_type(column.base_column, inner=inner) + else: + inner_type = get_pydantic_value_type(column.base_column) + + return list[inner_type] # type: ignore + + +def get_pydantic_value_type(column: Column) -> type: + """ + Map the Piccolo ``Column`` to a Pydantic type. 
+ """ + value_type: type + + if isinstance(column, (Decimal, Numeric)): + value_type = pydantic.condecimal( + max_digits=column.precision, decimal_places=column.scale + ) + elif isinstance(column, Email): + value_type = pydantic.EmailStr # type: ignore + elif isinstance(column, Varchar): + value_type = pydantic.constr(max_length=column.length) + elif isinstance(column, Array): + value_type = get_array_value_type(column=column) + else: + value_type = column.value_type + + return value_type + + +def create_pydantic_model( + table: type[Table], + nested: Union[bool, tuple[ForeignKey, ...]] = False, + exclude_columns: tuple[Column, ...] = (), + include_columns: tuple[Column, ...] = (), + include_default_columns: bool = False, + include_readable: bool = False, + all_optional: bool = False, + model_name: Optional[str] = None, + deserialize_json: bool = False, + recursion_depth: int = 0, + max_recursion_depth: int = 5, + pydantic_config: Optional[pydantic.config.ConfigDict] = None, + json_schema_extra: Optional[dict[str, Any]] = None, +) -> type[pydantic.BaseModel]: + """ + Create a Pydantic model representing a table. + + :param table: + The Piccolo ``Table`` you want to create a Pydantic serialiser model + for. + :param nested: + Whether ``ForeignKey`` columns are converted to nested Pydantic models. + If ``False``, none are converted. If ``True``, they all are converted. + If a tuple of ``ForeignKey`` columns is passed in, then only those are + converted. + :param exclude_columns: + A tuple of ``Column`` instances that should be excluded from the + Pydantic model. Only specify ``include_columns`` or + ``exclude_columns``. + :param include_columns: + A tuple of ``Column`` instances that should be included in the + Pydantic model. Only specify ``include_columns`` or + ``exclude_columns``. + :param include_default_columns: + Whether to include columns like ``id`` in the serialiser. You will + typically include these columns in GET requests, but don't require + them in POST requests. + :param include_readable: + Whether to include 'readable' columns, which give a string + representation of a foreign key. + :param all_optional: + If True, all fields are optional. Useful for filters etc. + :param model_name: + By default, the classname of the Piccolo ``Table`` will be used, but + you can override it if you want multiple Pydantic models based off the + same Piccolo table. + :param deserialize_json: + By default, the values of any Piccolo ``JSON`` or ``JSONB`` columns are + returned as strings. By setting this parameter to ``True``, they will + be returned as objects. + :param recursion_depth: + Not to be set by the user - used internally to track recursion. + :param max_recursion_depth: + If using nested models, this specifies the max amount of recursion. + :param pydantic_config: + Allows you to configure some of Pydantic's behaviour. See the + `Pydantic docs `_ + for more info. + :param json_schema_extra: + This can be used to add additional fields to the schema. This is + very useful when using Pydantic's JSON Schema features. For example: + + .. code-block:: python + + >>> my_model = create_pydantic_model(Band, my_extra_field="Hello") + >>> my_model.model_json_schema() + {..., "my_extra_field": "Hello"} + + :returns: + A Pydantic model. + + """ # noqa: E501 + if exclude_columns and include_columns: + raise ValueError( + "`include_columns` and `exclude_columns` can't be used at the " + "same time." 
+ ) + + if recursion_depth == 0: + if exclude_columns: + if not validate_columns(columns=exclude_columns, table=table): + raise ValueError( + f"`exclude_columns` are invalid: {exclude_columns!r}" + ) + + if include_columns: + if not validate_columns(columns=include_columns, table=table): + raise ValueError( + f"`include_columns` are invalid: {include_columns!r}" + ) + + ########################################################################### + + columns: dict[str, Any] = {} + validators: dict[str, Callable] = {} + + piccolo_columns = tuple( + table._meta.columns + if include_default_columns + else table._meta.non_default_columns + ) + + if include_columns: + include_columns_plus_ancestors = list( + itertools.chain( + include_columns, *[i._meta.call_chain for i in include_columns] + ) + ) + piccolo_columns = tuple( + i + for i in piccolo_columns + if any( + i._equals(include_column) + for include_column in include_columns_plus_ancestors + ) + ) + + if exclude_columns: + piccolo_columns = tuple( + i + for i in piccolo_columns + if not any( + i._equals(exclude_column) for exclude_column in exclude_columns + ) + ) + + model_name = model_name or table.__name__ + + for column in piccolo_columns: + column_name = column._meta.name + + is_optional = True if all_optional else not column._meta.required + + ####################################################################### + # Work out the column type + + if isinstance(column, (JSON, JSONB)): + if deserialize_json: + value_type = pydantic.Json + else: + value_type = column.value_type + validator = partial( + pydantic_json_validator, required=not is_optional + ) + validators[ + f"{column_name}_is_json" + ] = pydantic.field_validator(column_name)( + validator # type: ignore + ) + else: + value_type = get_pydantic_value_type(column=column) + + _type = Optional[value_type] if is_optional else value_type + + ####################################################################### + + params: dict[str, Any] = {} + if is_optional: + params["default"] = None + + if column._meta.db_column_name != column._meta.name: + params["alias"] = column._meta.db_column_name + + extra: JsonDict = { + "help_text": column._meta.help_text, + "choices": column._meta.get_choices_dict(), + "secret": column._meta.secret, + "nullable": column._meta.null, + "unique": column._meta.unique, + } + + if isinstance(column, ForeignKey): + if recursion_depth < max_recursion_depth and ( + (nested is True) + or ( + isinstance(nested, tuple) + and any( + column._equals(i) + for i in itertools.chain( + nested, *[i._meta.call_chain for i in nested] + ) + ) + ) + ): + nested_model_name = f"{model_name}.{column._meta.name}" + _type = create_pydantic_model( + table=column._foreign_key_meta.resolved_references, + nested=nested, + include_columns=include_columns, + exclude_columns=exclude_columns, + include_default_columns=include_default_columns, + include_readable=include_readable, + all_optional=all_optional, + deserialize_json=deserialize_json, + recursion_depth=recursion_depth + 1, + max_recursion_depth=max_recursion_depth, + model_name=nested_model_name, + ) + + tablename = ( + column._foreign_key_meta.resolved_references._meta.tablename + ) + target_column = ( + column._foreign_key_meta.resolved_target_column._meta.name + ) + extra["foreign_key"] = { + "to": tablename, + "target_column": target_column, + } + + if include_readable: + columns[f"{column_name}_readable"] = (str, None) + else: + # This is used to tell Piccolo Admin that we want to display these + # values using a specific 
widget. + if isinstance(column, Text): + extra["widget"] = "text-area" + elif isinstance(column, (JSON, JSONB)): + extra["widget"] = "json" + elif isinstance(column, Timestamptz): + extra["widget"] = "timestamptz" + + # It is useful for Piccolo API and Piccolo Admin to easily know + # how many dimensions the array has. + if isinstance(column, Array): + extra["dimensions"] = column._get_dimensions() + + field = pydantic.Field( + json_schema_extra={"extra": extra}, + **params, + ) + + columns[column_name] = (_type, field) + + pydantic_config = ( + pydantic_config.copy() + if pydantic_config + else pydantic.config.ConfigDict() + ) + pydantic_config["arbitrary_types_allowed"] = True + + json_schema_extra_ = defaultdict(dict, **(json_schema_extra or {})) + json_schema_extra_["extra"]["help_text"] = table._meta.help_text + + pydantic_config["json_schema_extra"] = dict(json_schema_extra_) + + model = pydantic.create_model( + model_name, + __config__=pydantic_config, + __validators__=validators, + **columns, + ) + model.__qualname__ = model_name + + return model diff --git a/piccolo/utils/repr.py b/piccolo/utils/repr.py index 2e8ed8c52..d3750c104 100644 --- a/piccolo/utils/repr.py +++ b/piccolo/utils/repr.py @@ -20,6 +20,7 @@ def repr_class_instance(class_instance: object) -> str: args_dict[arg_name] = value args_str = ", ".join( - [f"{key}={value.__repr__()}" for key, value in args_dict.items()] + f"{key}={value.__repr__()}" for key, value in args_dict.items() ) + return f"{class_instance.__class__.__name__}({args_str})" diff --git a/piccolo/utils/sql_values.py b/piccolo/utils/sql_values.py index cb57afcfc..4d44c96f5 100644 --- a/piccolo/utils/sql_values.py +++ b/piccolo/utils/sql_values.py @@ -1,28 +1,50 @@ from __future__ import annotations -import typing as t +import functools from enum import Enum +from typing import TYPE_CHECKING, Any from piccolo.utils.encoding import dump_json +from piccolo.utils.warnings import colored_warning -if t.TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column -def convert_to_sql_value(value: t.Any, column: Column) -> t.Any: +def convert_to_sql_value(value: Any, column: Column) -> Any: """ Some values which can be passed into Piccolo queries aren't valid in the database. For example, Enums, Table instances, and dictionaries for JSON columns. """ - from piccolo.columns.column_types import JSON, JSONB + from piccolo.columns.column_types import JSON, JSONB, ForeignKey from piccolo.table import Table if isinstance(value, Table): - return getattr(value, value._meta.primary_key._meta.name) + if isinstance(column, ForeignKey): + return getattr( + value, + column._foreign_key_meta.resolved_target_column._meta.name, + ) + elif column._meta.primary_key: + return getattr(value, column._meta.name) + else: + raise ValueError( + "Table instance provided, and the column isn't a ForeignKey, " + "or primary key column." + ) elif isinstance(value, Enum): return value.value elif isinstance(column, (JSON, JSONB)) and not isinstance(value, str): - return dump_json(value) + return None if value is None else dump_json(value) + elif isinstance(value, list): + if len(value) > 100: + colored_warning( + "When using large lists, consider bypassing the ORM and " + "using SQL directly for improved performance." + ) + # Attempt to do this as performantly as possible. 
+ func = functools.partial(convert_to_sql_value, column=column) + return list(map(func, value)) else: return value diff --git a/piccolo/utils/sync.py b/piccolo/utils/sync.py index dbed89e14..62aea2a38 100644 --- a/piccolo/utils/sync.py +++ b/piccolo/utils/sync.py @@ -1,33 +1,29 @@ from __future__ import annotations import asyncio -import typing as t -from concurrent.futures import ThreadPoolExecutor +from collections.abc import Coroutine +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, TypeVar +ReturnType = TypeVar("ReturnType") -def run_sync(coroutine: t.Coroutine): + +def run_sync( + coroutine: Coroutine[Any, Any, ReturnType], +) -> ReturnType: """ Run the coroutine synchronously - trying to accommodate as many edge cases as possible. - 1. When called within a coroutine. - 2. When called from `python -m asyncio`, or iPython with %autoawait + 1. When called from ``python -m asyncio``, or iPython with %autoawait enabled, which means an event loop may already be running in the current thread. - """ try: - loop = asyncio.get_event_loop() - except RuntimeError: + # We try this first, as in most situations this will work. return asyncio.run(coroutine) - else: - if loop.is_running(): - new_loop = asyncio.new_event_loop() - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit( - new_loop.run_until_complete, coroutine - ) - return future.result() - else: - return loop.run_until_complete(coroutine) + except RuntimeError: + # An event loop is already running in this thread. + with ThreadPoolExecutor(max_workers=1) as executor: + future: Future = executor.submit(asyncio.run, coroutine) + return future.result() diff --git a/piccolo/utils/warnings.py b/piccolo/utils/warnings.py index ef1fcba90..f5f36852f 100644 --- a/piccolo/utils/warnings.py +++ b/piccolo/utils/warnings.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing as t import warnings from enum import Enum @@ -21,7 +20,7 @@ def colored_string(message: str, level: Level = Level.medium) -> str: def colored_warning( message: str, - category: t.Type[Warning] = Warning, + category: type[Warning] = Warning, stacklevel: int = 3, level: Level = Level.medium, ): diff --git a/piccolo_conf.py b/piccolo_conf.py index a020d97b4..6ece2f685 100644 --- a/piccolo_conf.py +++ b/piccolo_conf.py @@ -11,7 +11,6 @@ from piccolo.conf.apps import AppRegistry from piccolo.engine.postgres import PostgresEngine - DB = PostgresEngine(config={}) diff --git a/profiling/README.md b/profiling/README.md new file mode 100644 index 000000000..22e6eb852 --- /dev/null +++ b/profiling/README.md @@ -0,0 +1,5 @@ +# Profiling + +Tests we run to evaluate Piccolo performance. + +You need to set up a local Postgres database called 'piccolo_profile'.
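To satisfy the prerequisite in the profiling README above, `createdb piccolo_profile` is usually enough. If you'd rather do it programmatically, here's a minimal sketch using asyncpg (one of Piccolo's optional dependencies) - the helper and the connection credentials are assumptions for a default local install, not something the repo prescribes:

```python
# Hypothetical helper (not part of this repo) for creating the profiling
# database. Adjust user/password to match your local Postgres setup.
import asyncio

import asyncpg


async def create_profile_db() -> None:
    # Connect to the default maintenance database, since 'piccolo_profile'
    # doesn't exist yet.
    conn = await asyncpg.connect(
        user="postgres", password="postgres", database="postgres"
    )
    await conn.execute("CREATE DATABASE piccolo_profile")
    await conn.close()


if __name__ == "__main__":
    asyncio.run(create_profile_db())
```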
diff --git a/profiling/__init__.py b/profiling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/profiling/run_profile.py b/profiling/run_profile.py new file mode 100644 index 000000000..9dc53e081 --- /dev/null +++ b/profiling/run_profile.py @@ -0,0 +1,40 @@ +import asyncio + +from viztracer import VizTracer + +from piccolo.columns.column_types import Varchar +from piccolo.engine.postgres import PostgresEngine +from piccolo.table import Table + +DB = PostgresEngine(config={"database": "piccolo_profile"}) + + +class Band(Table, db=DB): + name = Varchar() + + +async def setup(): + await Band.alter().drop_table(if_exists=True) + await Band.create_table(if_not_exists=True) + await Band.insert(*[Band(name="test") for _ in range(1000)]) + + +class Trace: + def __enter__(self): + self.tracer = VizTracer(log_async=True) + self.tracer.start() + + def __exit__(self, *args): + self.tracer.stop() + self.tracer.save() + + +async def run_queries(): + await setup() + + with Trace(): + await Band.select() + + +if __name__ == "__main__": + asyncio.run(run_queries()) diff --git a/pyproject.toml b/pyproject.toml index b489446a7..0c94768d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,47 @@ [tool.black] line-length = 79 -target-version = ['py37', 'py38', 'py39'] +target-version = ['py38', 'py39', 'py310'] [tool.isort] profile = "black" line_length = 79 + +[tool.mypy] +[[tool.mypy.overrides]] +module = [ + "asyncpg.*", + "colorama", + "dateutil", + "IPython", + "IPython.core.interactiveshell", + "jinja2", + "orjson", + "aiosqlite", + "uvicorn" +] +ignore_missing_imports = true + + +[tool.pytest.ini_options] +markers = [ + "integration", + "speed", + "cockroach_array_slow" +] + +[tool.coverage.run] +omit = [ + "*.jinja", + "**/piccolo_migrations/*", + "**/piccolo_app.py", + "**/utils/graphlib/*", +] + +[tool.coverage.report] +# Note, we have to re-specify "pragma: no cover" +# https://coverage.readthedocs.io/en/6.3.3/excluding.html#advanced-exclusion +exclude_lines = [ + "raise NotImplementedError", + "pragma: no cover", + "pass" +] diff --git a/requirements/README.md b/requirements/README.md index 07faa574f..235d5e4c9 100644 --- a/requirements/README.md +++ b/requirements/README.md @@ -1,6 +1,8 @@ # Requirement files -* `extras` - Optional dependencies of `Piccolo`. -* `dev-requirements.txt` - Requirements needed to develop `Piccolo`. -* `requirements.txt` - Default requirements of `Piccolo`. -* `test-requirements.txt` - Requirements needed to run `Piccolo` tests. +- `extras` - optional dependencies of `Piccolo`. +- `dev-requirements.txt` - needed to develop `Piccolo`. +- `requirements.txt` - default requirements of `Piccolo`. +- `test-requirements.txt` - needed to run `Piccolo` tests. +- `doc-requirements.txt` - needed to run the `Piccolo` docs +- `readthedocs-requirements.txt` - just used by ReadTheDocs. 
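The pytest markers registered in `pyproject.toml` above (`integration`, `speed`, `cockroach_array_slow`) pair with the test scripts later in this diff, which invoke pytest with `-m "not integration"` so slow end-to-end tests only run when explicitly requested. A hypothetical sketch of the intended usage (the test name and body are illustrative, not from the repo):

```python
# Illustrative only: a test carrying the 'integration' marker declared in
# pyproject.toml. It's excluded by the default test scripts, and selected
# by scripts/test-integration.sh via `pytest -m integration`.
import pytest


@pytest.mark.integration
def test_asgi_app_boots():
    # A slow, end-to-end style check would go here.
    assert True
```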
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index c4a733841..6807bf875 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -1,9 +1,11 @@ -black>=21.7b0 -ipdb==0.12.2 -ipython==7.8.0 -flake8==3.8.4 -isort==5.9.2 -twine==3.1.1 -mypy==0.782 +black==24.3.0 +ipdb==0.13.9 +ipython>=7.31.1 +flake8==6.1.0 +isort==5.10.1 +slotscheck==0.17.1 +twine==3.8.0 +mypy==1.18.1 pip-upgrader==1.4.15 -wheel==0.36.2 +pyright==1.1.367 +wheel==0.38.1 diff --git a/requirements/doc-requirements.txt b/requirements/doc-requirements.txt new file mode 100644 index 000000000..525aded5c --- /dev/null +++ b/requirements/doc-requirements.txt @@ -0,0 +1,3 @@ +Sphinx==8.3.0 +piccolo-theme==0.24.0 +sphinx-autobuild==2025.8.25 diff --git a/requirements/extras/orjson.txt b/requirements/extras/orjson.txt index b21219704..31eb4cfa7 100644 --- a/requirements/extras/orjson.txt +++ b/requirements/extras/orjson.txt @@ -1 +1 @@ -orjson==3.4.1 +orjson>=3.5.1 diff --git a/requirements/extras/postgres.txt b/requirements/extras/postgres.txt index 864747fa4..1b54800e6 100644 --- a/requirements/extras/postgres.txt +++ b/requirements/extras/postgres.txt @@ -1 +1 @@ -asyncpg>=0.21.0 +asyncpg>=0.30.0 diff --git a/requirements/extras/pytest.txt b/requirements/extras/pytest.txt new file mode 100644 index 000000000..e079f8a60 --- /dev/null +++ b/requirements/extras/pytest.txt @@ -0,0 +1 @@ +pytest diff --git a/requirements/extras/uvloop.txt b/requirements/extras/uvloop.txt index a51c94482..8bb45b44c 100644 --- a/requirements/extras/uvloop.txt +++ b/requirements/extras/uvloop.txt @@ -1 +1 @@ -uvloop>=0.12.0 +uvloop>=0.12.0; sys_platform != "win32" diff --git a/requirements/profile-requirements.txt b/requirements/profile-requirements.txt new file mode 100644 index 000000000..7c0fca7d8 --- /dev/null +++ b/requirements/profile-requirements.txt @@ -0,0 +1 @@ +viztracer==0.15.0 diff --git a/requirements/readthedocs-requirements.txt b/requirements/readthedocs-requirements.txt new file mode 100644 index 000000000..8db118e9b --- /dev/null +++ b/requirements/readthedocs-requirements.txt @@ -0,0 +1,2 @@ +-r requirements.txt +-r doc-requirements.txt diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 06c21db8a..0a5ee6244 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,7 @@ black colorama>=0.4.0 Jinja2>=2.11.0 -targ>=0.3.3 +targ>=0.3.7 inflection>=0.5.1 +typing-extensions>=4.3.0 +pydantic[email]==2.* diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index 5f744295f..f419596cf 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -1,4 +1,6 @@ -coveralls==2.2.0 -pytest-cov==2.10.1 -pytest==6.2.1 -python-dateutil==2.8.1 +coveralls==3.3.1 +httpx==0.28.0 +pytest-cov==3.0.0 +pytest==8.3.5 +python-dateutil==2.8.2 +typing-extensions>=4.3.0 diff --git a/scripts/README.md b/scripts/README.md index 10ba8b7b6..6ffbafa09 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -4,8 +4,10 @@ The scripts follow GitHub's ["Scripts to Rule Them All"](https://github.com/gith Call them from the root of the project, e.g. `./scripts/lint.sh`. -* `scripts/lint.sh` - Run the automated code linting/formatting tools. -* `scripts/piccolo.sh` - Run the Piccolo CLI on the example project in the `tests` folder. -* `scripts/release.sh` - Publish package to PyPI. -* `scripts/test-postgres.sh` - Run the test suite with Postgres. 
-* `scripts/test-sqlite.sh` - Run the test suite with SQLite. +- `scripts/format.sh` - Format the code to the required standards. +- `scripts/lint.sh` - Run the automated code linting/formatting tools. +- `scripts/piccolo.sh` - Run the Piccolo CLI on the example project in the `tests` folder. +- `scripts/profile.sh` - Run a profiler to test performance. +- `scripts/release.sh` - Publish package to PyPI. +- `scripts/test-postgres.sh` - Run the test suite with Postgres. +- `scripts/test-sqlite.sh` - Run the test suite with SQLite. diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100755 index 000000000..482e0a21d --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,9 @@ +#!/bin/bash +SOURCES="piccolo tests" + +echo "Running isort..." +isort $SOURCES +echo "-----" + +echo "Running black..." +black $SOURCES diff --git a/scripts/lint.sh b/scripts/lint.sh index 5326db07c..14582f903 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -1,8 +1,27 @@ #!/bin/bash +set -e -SOURCES="piccolo tests" +MODULES="piccolo" +SOURCES="$MODULES tests" -isort $SOURCES -black $SOURCES +echo "Running isort..." +isort --check $SOURCES +echo "-----" + +echo "Running black..." +black --check $SOURCES +echo "-----" + +echo "Running flake8..." flake8 $SOURCES +echo "-----" + +echo "Running mypy..." mypy $SOURCES +echo "-----" + +echo "Running slotscheck..." +python -m slotscheck $MODULES +echo "-----" + +echo "All passed!" diff --git a/scripts/profile.sh b/scripts/profile.sh new file mode 100755 index 000000000..5396f2bcf --- /dev/null +++ b/scripts/profile.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m profiling.run_profile && vizviewer result.json diff --git a/scripts/pyright.sh b/scripts/pyright.sh new file mode 100755 index 000000000..616652eb8 --- /dev/null +++ b/scripts/pyright.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# We have a separate script for pyright vs lint.sh, as it's hard to get 100% +# success in pyright. In the future we might merge them. + +set -e + +MODULES="piccolo" +SOURCES="$MODULES tests" + +echo "Running pyright..." +pyright $SOURCES +echo "-----" + +echo "All passed!"
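`lint.sh` above now also runs slotscheck. As a generic illustration (deliberately not code from this repo) of the class of problem slotscheck is designed to catch - a subclass of a slotted class that omits `__slots__` silently regains a per-instance `__dict__`:

```python
class Base:
    __slots__ = ("name",)  # instances get no __dict__


class Child(Base):  # omitting __slots__ here defeats Base's __slots__
    pass


child = Child()
child.nmae = "typo"  # no AttributeError: Child instances have a __dict__
print(child.__dict__)  # {'nmae': 'typo'} - the memory savings are lost too
```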
diff --git a/scripts/run-docs.sh b/scripts/run-docs.sh new file mode 100755 index 000000000..bac9a6ebe --- /dev/null +++ b/scripts/run-docs.sh @@ -0,0 +1,2 @@ +#!/bin/bash +sphinx-autobuild -a docs/src docs/build/html --watch piccolo diff --git a/scripts/test-cockroach.sh b/scripts/test-cockroach.sh new file mode 100755 index 000000000..12b102448 --- /dev/null +++ b/scripts/test-cockroach.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# To run all in a folder tests/ +# To run all in a file tests/test_foo.py +# To run all in a class tests/test_foo.py::TestFoo +# To run a single test tests/test_foo.py::TestFoo::test_foo + +export PICCOLO_CONF="tests.cockroach_conf" +python -m pytest \ + --cov=piccolo \ + --cov-report=xml \ + --cov-report=html \ + --cov-fail-under=80 \ + -m "not integration and not cockroach_array_slow" \ + -s $@ diff --git a/scripts/test-integration.sh b/scripts/test-integration.sh new file mode 100755 index 000000000..41afc1823 --- /dev/null +++ b/scripts/test-integration.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# To run all in a folder tests/ +# To run all in a file tests/test_foo.py +# To run all in a class tests/test_foo.py::TestFoo +# To run a single test tests/test_foo.py::TestFoo::test_foo + +export PICCOLO_CONF="tests.postgres_conf" +python -m pytest \ + -m integration \ + -s $@ diff --git a/scripts/test-postgres.sh b/scripts/test-postgres.sh index 0c59964f0..9f853b734 100755 --- a/scripts/test-postgres.sh +++ b/scripts/test-postgres.sh @@ -5,4 +5,10 @@ # To run a single test tests/test_foo.py::TestFoo::test_foo export PICCOLO_CONF="tests.postgres_conf" -python -m pytest --cov=piccolo --cov-report xml --cov-report html --cov-fail-under 85 -s $@ +python -m pytest \ + --cov=piccolo \ + --cov-report=xml \ + --cov-report=html \ + --cov-fail-under=85 \ + -m "not integration" \ + -s $@ diff --git a/scripts/test-sqlite.sh b/scripts/test-sqlite.sh index e14ba4aa4..fa53de8bc 100755 --- a/scripts/test-sqlite.sh +++ b/scripts/test-sqlite.sh @@ -5,4 +5,10 @@ # To run a single test tests/test_foo.py::TestFoo::test_foo export PICCOLO_CONF="tests.sqlite_conf" -python -m pytest --cov=piccolo --cov-report xml --cov-report html --cov-fail-under 75 -s $@ +python -m pytest \ + --cov=piccolo \ + --cov-report=xml \ + --cov-report=html \ + --cov-fail-under=75 \ + -m "not integration" \ + -s $@ diff --git a/scripts/test-strict.sh b/scripts/test-strict.sh new file mode 100755 index 000000000..08cc48a2b --- /dev/null +++ b/scripts/test-strict.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# This runs the tests in Python's development mode: +# https://docs.python.org/3/library/devmode.html +# It shows us deprecation warnings, and asyncio warnings. + +# To run all in a folder tests/ +# To run all in a file tests/test_foo.py +# To run all in a class tests/test_foo.py::TestFoo +# To run a single test tests/test_foo.py::TestFoo::test_foo + +export PICCOLO_CONF="tests.postgres_conf" +python -X dev -m pytest -m "not integration" -s $@ diff --git a/setup.py b/setup.py index 51d642c5b..2f2f320e0 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import itertools import os -from typing import List + from setuptools import find_packages, setup from piccolo import __VERSION__ as VERSION - directory = os.path.abspath(os.path.dirname(__file__)) extras = ["orjson", "playground", "postgres", "sqlite", "uvloop"] @@ -17,7 +17,7 @@ LONG_DESCRIPTION = f.read() -def parse_requirement(req_path: str) -> List[str]: +def parse_requirement(req_path: str) -> list[str]: """ Parse requirement file. 
Example: @@ -25,21 +25,24 @@ def parse_requirement(req_path: str) -> list[str]: parse_requirement('extras/playground.txt') # requirements/extras/playground.txt Returns: List[str]: list of requirements specified in the file. - """ + """ # noqa: E501 with open(os.path.join(directory, "requirements", req_path)) as f: contents = f.read() return [i.strip() for i in contents.strip().split("\n")] -def extras_require(): +def extras_require() -> dict[str, list[str]]: """ Parse requirements in requirements/extras directory """ - extra_requirements = {} - for extra in extras: - extra_requirements[extra] = parse_requirement( - os.path.join("extras", extra + ".txt") - ) + extra_requirements = { + extra: parse_requirement(os.path.join("extras", f"{extra}.txt")) + for extra in extras + } + + extra_requirements["all"] = list( + itertools.chain.from_iterable(extra_requirements.values()) + ) return extra_requirements @@ -54,7 +57,7 @@ def extras_require(): long_description_content_type="text/markdown", author="Daniel Townsend", author_email="dan@dantownsend.co.uk", - python_requires=">=3.7.0", + python_requires=">=3.10.0", url="https://github.com/piccolo-orm/piccolo", packages=find_packages(exclude=("tests",)), package_data={ @@ -66,6 +69,13 @@ def extras_require(): ], "piccolo": ["py.typed"], }, + project_urls={ + "Documentation": ( + "https://piccolo-orm.readthedocs.io/en/latest/index.html" + ), + "Source": "https://github.com/piccolo-orm/piccolo", + "Tracker": "https://github.com/piccolo-orm/piccolo/issues", + }, install_requires=parse_requirement("requirements.txt"), extras_require=extras_require(), license="MIT", @@ -73,9 +83,11 @@ def extras_require(): "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Framework :: AsyncIO", "Typing :: Typed", diff --git a/tests/apps/app/commands/test_new.py b/tests/apps/app/commands/test_new.py index efdacdb5f..bcd89a045 100644 --- a/tests/apps/app/commands/test_new.py +++ b/tests/apps/app/commands/test_new.py @@ -3,7 +3,12 @@ import tempfile from unittest import TestCase -from piccolo.apps.app.commands.new import module_exists, new +from piccolo.apps.app.commands.new import ( + get_app_module, + module_exists, + new, + validate_app_name, +) class TestModuleExists(TestCase): @@ -39,5 +44,41 @@ def test_new_with_clashing_name(self): exception = context.exception self.assertTrue( - exception.code.startswith("A module called sys already exists") + str(exception.code).startswith( + "A module called sys already exists" + ) ) + + +class TestValidateAppName(TestCase): + + def test_validate_app_name(self): + """ + Make sure only app names which work as valid Python package names are + allowed. + """ + # Should be rejected: + for app_name in ("MY APP", "app/my_app", "my.app"): + with self.assertRaises(ValueError): + validate_app_name(app_name=app_name) + + # Should work fine: + validate_app_name(app_name="music") + + +class TestGetAppIdentifier(TestCase): + + def test_get_app_module(self): + """ + Make sure the ``root`` argument is handled correctly.
+ """ + self.assertEqual( + get_app_module(app_name="music", root="."), + "music.piccolo_app", + ) + + for root in ("apps", "./apps", "./apps/"): + self.assertEqual( + get_app_module(app_name="music", root=root), + "apps.music.piccolo_app", + ) diff --git a/tests/apps/app/commands/test_show_all.py b/tests/apps/app/commands/test_show_all.py index 0800c042d..92404a412 100644 --- a/tests/apps/app/commands/test_show_all.py +++ b/tests/apps/app/commands/test_show_all.py @@ -11,5 +11,9 @@ def test_show_all(self, print_: MagicMock): self.assertEqual( print_.mock_calls, - [call("Registered apps:"), call("tests.example_app.piccolo_app")], + [ + call("Registered apps:"), + call("tests.example_apps.music.piccolo_app"), + call("tests.example_apps.mega.piccolo_app"), + ], ) diff --git a/tests/apps/asgi/commands/files/dummy_server.py b/tests/apps/asgi/commands/files/dummy_server.py new file mode 100644 index 000000000..709a98a77 --- /dev/null +++ b/tests/apps/asgi/commands/files/dummy_server.py @@ -0,0 +1,42 @@ +import asyncio +import importlib +import sys +from collections.abc import Callable +from typing import Union, cast + +from httpx import ASGITransport, AsyncClient +from uvicorn import Config, Server + + +async def dummy_server(app: Union[str, Callable] = "app:app") -> None: + """ + A very simplistic ASGI server. It's used to run the generated ASGI + applications in unit tests. + + :param app: + Either an ASGI app, or a string representing the path to an ASGI app. + For example, ``module_1.app:app`` which would import an ASGI app called + ``app`` from ``module_1.app``. + + """ + print("Running dummy server ...") + + if isinstance(app, str): + path, app_name = app.rsplit(":") + module = importlib.import_module(path) + app = cast(Callable, getattr(module, app_name)) + + try: + async with AsyncClient(transport=ASGITransport(app=app)) as client: + response = await client.get("http://localhost:8000") + if response.status_code != 200: + sys.exit("The app isn't callable!") + except Exception: + config = Config(app=app) + server = Server(config=config) + asyncio.create_task(server.serve()) + await asyncio.sleep(0.1) + + +if __name__ == "__main__": + asyncio.run(dummy_server()) diff --git a/tests/apps/asgi/commands/test_new.py b/tests/apps/asgi/commands/test_new.py index 68dcc84b0..fa4a99cab 100644 --- a/tests/apps/asgi/commands/test_new.py +++ b/tests/apps/asgi/commands/test_new.py @@ -1,28 +1,95 @@ +import ast import os import shutil +import subprocess import tempfile +from pathlib import Path from unittest import TestCase from unittest.mock import patch +import pytest + from piccolo.apps.asgi.commands.new import ROUTERS, SERVERS, new +from tests.base import unix_only class TestNewApp(TestCase): - @patch( - "piccolo.apps.asgi.commands.new.get_routing_framework", - return_value=ROUTERS[0], - ) - @patch( - "piccolo.apps.asgi.commands.new.get_server", - return_value=SERVERS[0], - ) - def test_new(self, *args, **kwargs): - root = os.path.join(tempfile.gettempdir(), "asgi_app") - - if os.path.exists(root): - shutil.rmtree(root) - - os.mkdir(root) - new(root=root) - - self.assertTrue(os.path.exists(os.path.join(root, "app.py"))) + def test_new(self): + """ + Test that the created files have the correct content. List all .py + files inside the root directory and check if they are valid python code + with ast.parse. 
+ """ + for router in ROUTERS: + for server in SERVERS: + with patch( + "piccolo.apps.asgi.commands.new.get_routing_framework", + return_value=router, + ), patch( + "piccolo.apps.asgi.commands.new.get_server", + return_value=server, + ): + root = os.path.join(tempfile.gettempdir(), "asgi_app") + + if os.path.exists(root): + shutil.rmtree(root) + + os.mkdir(root) + new(root=root) + + # Make sure the files were created + self.assertTrue( + os.path.exists(os.path.join(root, "app.py")) + ) + + # Make sure the Python code is valid. + for file in list(Path(root).rglob("*.py")): + with open(os.path.join(root, file), "r") as f: + ast.parse(f.read()) + f.close() + + +class TestNewAppRuns(TestCase): + @unix_only + @pytest.mark.integration + def test_new(self): + """ + Test that the ASGI app actually runs. + """ + for router in ROUTERS: + with patch( + "piccolo.apps.asgi.commands.new.get_routing_framework", + return_value=router, + ), patch( + "piccolo.apps.asgi.commands.new.get_server", + return_value=SERVERS[0], + ): + root = os.path.join(tempfile.gettempdir(), "asgi_app") + + if os.path.exists(root): + shutil.rmtree(root) + + os.mkdir(root) + new(root=root) + + # Copy a dummy ASGI server, so we can test that the server + # works. + shutil.copyfile( + os.path.join( + os.path.dirname(__file__), + "files", + "dummy_server.py", + ), + os.path.join(root, "dummy_server.py"), + ) + + response = subprocess.run( + f"cd {root} && " + "python -m venv venv && " + "./venv/bin/pip install -r requirements.txt && " + "./venv/bin/python dummy_server.py", + shell=True, + ) + self.assertEqual( + response.returncode, 0, msg=f"{router} failed" + ) diff --git a/tests/apps/fixtures/__init__.py b/tests/apps/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/apps/fixtures/commands/__init__.py b/tests/apps/fixtures/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/apps/fixtures/commands/test_dump_load.py b/tests/apps/fixtures/commands/test_dump_load.py new file mode 100644 index 000000000..728f2f5c0 --- /dev/null +++ b/tests/apps/fixtures/commands/test_dump_load.py @@ -0,0 +1,277 @@ +import datetime +import decimal +import os +import tempfile +import uuid +from unittest import TestCase + +from piccolo.apps.fixtures.commands.dump import ( + FixtureConfig, + dump_to_json_string, +) +from piccolo.apps.fixtures.commands.load import load, load_json_string +from piccolo.utils.sync import run_sync +from tests.base import engines_only +from tests.example_apps.mega.tables import MegaTable, SmallTable + + +class TestDumpLoad(TestCase): + """ + Test the fixture dump and load commands - makes sense to test them + together. 
+ """ + + maxDiff = None + + def setUp(self): + for table_class in (SmallTable, MegaTable): + table_class.create_table().run_sync() + + def tearDown(self): + for table_class in (MegaTable, SmallTable): + table_class.alter().drop_table().run_sync() + + def insert_rows(self): + small_table = SmallTable(varchar_col="Test") + small_table.save().run_sync() + + SmallTable(varchar_col="Test 2").save().run_sync() + + mega_table = MegaTable( + bigint_col=1, + boolean_col=True, + bytea_col="hello".encode("utf8"), + date_col=datetime.date(year=2021, month=1, day=1), + foreignkey_col=small_table, + integer_col=1, + interval_col=datetime.timedelta(seconds=10), + json_col={"a": 1}, + jsonb_col={"a": 1}, + numeric_col=decimal.Decimal("1.1"), + real_col=1.1, + double_precision_col=1.344, + smallint_col=1, + text_col="hello", + timestamp_col=datetime.datetime(year=2021, month=1, day=1), + timestamptz_col=datetime.datetime( + year=2021, month=1, day=1, tzinfo=datetime.timezone.utc + ), + uuid_col=uuid.UUID("12783854-c012-4c15-8183-8eecb46f2c4e"), + varchar_col="hello", + unique_col="hello", + null_col=None, + not_null_col="hello", + ) + mega_table.save().run_sync() + + def _run_comparison(self, table_class_names: list[str]): + self.insert_rows() + + json_string = run_sync( + dump_to_json_string( + fixture_configs=[ + FixtureConfig( + app_name="mega", + table_class_names=table_class_names, + ) + ] + ) + ) + + # We need to clear the data out now, otherwise when loading the data + # back in, there will be constraint errors over clashing primary + # keys. + SmallTable.delete(force=True).run_sync() + MegaTable.delete(force=True).run_sync() + + run_sync(load_json_string(json_string)) + + self.assertEqual( + SmallTable.select().run_sync(), + [ + {"id": 1, "varchar_col": "Test"}, + {"id": 2, "varchar_col": "Test 2"}, + ], + ) + + mega_table_data = MegaTable.select().run_sync() + + # Real numbers don't have perfect precision when coming back from the + # database, so we need to round them to be able to compare them. + mega_table_data[0]["real_col"] = round( + mega_table_data[0]["real_col"], 1 + ) + + # Remove white space from the JSON values + for col_name in ("json_col", "jsonb_col"): + mega_table_data[0][col_name] = mega_table_data[0][ + col_name + ].replace(" ", "") + + self.assertTrue(len(mega_table_data) == 1) + + self.assertDictEqual( + mega_table_data[0], + { + "id": 1, + "bigint_col": 1, + "boolean_col": True, + "bytea_col": b"hello", + "date_col": datetime.date(2021, 1, 1), + "foreignkey_col": 1, + "integer_col": 1, + "interval_col": datetime.timedelta(seconds=10), + "json_col": '{"a":1}', + "jsonb_col": '{"a":1}', + "numeric_col": decimal.Decimal("1.1"), + "real_col": 1.1, + "double_precision_col": 1.344, + "smallint_col": 1, + "text_col": "hello", + "timestamp_col": datetime.datetime(2021, 1, 1, 0, 0), + "timestamptz_col": datetime.datetime( + 2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc + ), + "uuid_col": uuid.UUID("12783854-c012-4c15-8183-8eecb46f2c4e"), + "varchar_col": "hello", + "unique_col": "hello", + "null_col": None, + "not_null_col": "hello", + }, + ) + + # Make sure subsequent inserts work. + SmallTable().save().run_sync() + + @engines_only("postgres", "sqlite") + def test_dump_load(self): + """ + Make sure we can dump some rows into a JSON fixture, then load them + back into the database. 
+ """ + self._run_comparison(table_class_names=["SmallTable", "MegaTable"]) + + @engines_only("postgres", "sqlite") + def test_dump_load_ordering(self): + """ + Similar to `test_dump_load` - but we need to make sure it inserts + the data in the correct order, so foreign key constraints don't fail. + """ + self._run_comparison(table_class_names=["MegaTable", "SmallTable"]) + + @engines_only("cockroach") + def test_dump_load_cockroach(self): + """ + Similar to `test_dump_load`, except the schema is slightly different + for CockroachDB. + """ + self.insert_rows() + + json_string = run_sync( + dump_to_json_string( + fixture_configs=[ + FixtureConfig( + app_name="mega", + table_class_names=["SmallTable", "MegaTable"], + ) + ] + ) + ) + + # We need to clear the data out now, otherwise when loading the data + # back in, there will be constraint errors over clashing primary + # keys. + SmallTable.delete(force=True).run_sync() + MegaTable.delete(force=True).run_sync() + + run_sync(load_json_string(json_string)) + + result = SmallTable.select().run_sync()[0] + result.pop("id") + + self.assertDictEqual( + result, + {"varchar_col": "Test"}, + ) + + mega_table_data = MegaTable.select().run_sync() + + # Real numbers don't have perfect precision when coming back from the + # database, so we need to round them to be able to compare them. + mega_table_data[0]["real_col"] = round( + mega_table_data[0]["real_col"], 1 + ) + + # Remove white space from the JSON values + for col_name in ("json_col", "jsonb_col"): + mega_table_data[0][col_name] = mega_table_data[0][ + col_name + ].replace(" ", "") + + self.assertTrue(len(mega_table_data) == 1) + + mega_table_data = mega_table_data[0] + mega_table_data.pop("id") + mega_table_data.pop("foreignkey_col") + + self.assertDictEqual( + mega_table_data, + { + "bigint_col": 1, + "boolean_col": True, + "bytea_col": b"hello", + "date_col": datetime.date(2021, 1, 1), + "integer_col": 1, + "interval_col": datetime.timedelta(seconds=10), + "json_col": '{"a":1}', + "jsonb_col": '{"a":1}', + "numeric_col": decimal.Decimal("1.1"), + "real_col": 1.1, + "double_precision_col": 1.344, + "smallint_col": 1, + "text_col": "hello", + "timestamp_col": datetime.datetime(2021, 1, 1, 0, 0), + "timestamptz_col": datetime.datetime( + 2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc + ), + "uuid_col": uuid.UUID("12783854-c012-4c15-8183-8eecb46f2c4e"), + "varchar_col": "hello", + "unique_col": "hello", + "null_col": None, + "not_null_col": "hello", + }, + ) + + +class TestOnConflict(TestCase): + def setUp(self) -> None: + SmallTable.create_table().run_sync() + SmallTable({SmallTable.varchar_col: "Test"}).save().run_sync() + + def tearDown(self) -> None: + SmallTable.alter().drop_table().run_sync() + + def test_on_conflict(self): + temp_dir = tempfile.gettempdir() + + json_file_path = os.path.join(temp_dir, "fixture.json") + + json_string = run_sync( + dump_to_json_string( + fixture_configs=[ + FixtureConfig( + app_name="mega", + table_class_names=["SmallTable"], + ) + ] + ) + ) + + if os.path.exists(json_file_path): + os.unlink(json_file_path) + + with open(json_file_path, "w") as f: + f.write(json_string) + + run_sync(load(path=json_file_path, on_conflict="DO NOTHING")) + run_sync(load(path=json_file_path, on_conflict="DO UPDATE")) diff --git a/tests/apps/fixtures/commands/test_shared.py b/tests/apps/fixtures/commands/test_shared.py new file mode 100644 index 000000000..34e2af4ee --- /dev/null +++ b/tests/apps/fixtures/commands/test_shared.py @@ -0,0 +1,60 @@ +import datetime +import decimal 
+import uuid +from unittest import TestCase + +from piccolo.apps.fixtures.commands.shared import ( + FixtureConfig, + create_pydantic_fixture_model, +) + + +class TestShared(TestCase): + def test_shared(self): + pydantic_model = create_pydantic_fixture_model( + fixture_configs=[ + FixtureConfig( + app_name="mega", + table_class_names=["MegaTable", "SmallTable"], + ) + ] + ) + + data = { + "mega": { + "SmallTable": [{"id": 1, "varchar_col": "Test"}], + "MegaTable": [ + { + "id": 1, + "bigint_col": 1, + "boolean_col": True, + "bytea_col": b"hello", + "date_col": datetime.date(2021, 1, 1), + "foreignkey_col": 1, + "integer_col": 1, + "interval_col": datetime.timedelta(seconds=10), + "json_col": '{"a":1}', + "jsonb_col": '{"a": 1}', + "numeric_col": decimal.Decimal("1.10"), + "real_col": 1.100000023841858, + "smallint_col": 1, + "text_col": "hello", + "timestamp_col": datetime.datetime(2021, 1, 1, 0, 0), + "timestamptz_col": datetime.datetime( + 2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc + ), + "uuid_col": uuid.UUID( + "12783854-c012-4c15-8183-8eecb46f2c4e" + ), + "varchar_col": "hello", + "unique_col": "hello", + "null_col": None, + "not_null_col": "hello", + } + ], + } + } + + model = pydantic_model(**data) + self.assertEqual(model.mega.SmallTable[0].id, 1) # type: ignore + self.assertEqual(model.mega.MegaTable[0].id, 1) # type: ignore diff --git a/tests/apps/meta/__init__.py b/tests/apps/meta/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/apps/meta/commands/__init__.py b/tests/apps/meta/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/apps/meta/commands/test_version.py b/tests/apps/meta/commands/test_version.py new file mode 100644 index 000000000..723c4ff35 --- /dev/null +++ b/tests/apps/meta/commands/test_version.py @@ -0,0 +1,11 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from piccolo.apps.meta.commands.version import version + + +class TestVersion(TestCase): + @patch("piccolo.apps.meta.commands.version.print") + def test_version(self, print_: MagicMock): + version() + print_.assert_called_once() diff --git a/tests/apps/migrations/auto/integration/test_migrations.py b/tests/apps/migrations/auto/integration/test_migrations.py index 0cd336502..a065d9588 100644 --- a/tests/apps/migrations/auto/integration/test_migrations.py +++ b/tests/apps/migrations/auto/integration/test_migrations.py @@ -1,40 +1,62 @@ from __future__ import annotations import datetime +import decimal import os +import random import shutil import tempfile import time -import typing as t import uuid -from unittest import TestCase +from collections.abc import Callable +from typing import TYPE_CHECKING, Optional +from unittest.mock import MagicMock, patch +from piccolo.apps.migrations.auto.operations import RenameTable +from piccolo.apps.migrations.commands.backwards import ( + BackwardsMigrationManager, +) from piccolo.apps.migrations.commands.forwards import ForwardsMigrationManager from piccolo.apps.migrations.commands.new import ( _create_migrations_folder, _create_new_migration, ) from piccolo.apps.migrations.tables import Migration +from piccolo.apps.schema.commands.generate import RowMeta from piccolo.columns.column_types import ( + JSON, + JSONB, UUID, + Array, BigInt, + BigSerial, Boolean, Date, + Decimal, + DoublePrecision, + ForeignKey, Integer, Interval, + Numeric, + Real, + Serial, SmallInt, Text, Time, Timestamp, + Timestamptz, Varchar, ) from piccolo.columns.defaults.uuid import UUID4 +from 
piccolo.columns.m2m import M2M +from piccolo.columns.reference import LazyTableReference from piccolo.conf.apps import AppConfig -from piccolo.table import Table, create_table_class +from piccolo.schema import SchemaManager +from piccolo.table import Table, create_table_class, drop_db_tables_sync from piccolo.utils.sync import run_sync -from tests.base import postgres_only +from tests.base import DBTestCase, engines_only, engines_skip -if t.TYPE_CHECKING: +if TYPE_CHECKING: from piccolo.columns.base import Column @@ -70,62 +92,164 @@ def boolean_default(): return True -@postgres_only -class TestMigrations(TestCase): - def tearDown(self): - create_table_class("MyTable").alter().drop_table( - if_exists=True - ).run_sync() - Migration.alter().drop_table(if_exists=True).run_sync() +def numeric_default(): + return decimal.Decimal("1.2") + - def run_migrations(self, app_config: AppConfig): - manager = ForwardsMigrationManager(app_name=app_config.app_name) - run_sync(manager.create_migration_table()) - run_sync(manager.run_migrations(app_config=app_config)) +def array_default_integer(): + return [4, 5, 6] - def _test_migrations(self, table_classes: t.List[t.Type[Table]]): + +def array_default_varchar(): + return ["x", "y", "z"] + + +class MigrationTestCase(DBTestCase): + def _run_migrations(self, app_config: AppConfig): + forwards_manager = ForwardsMigrationManager( + app_name=app_config.app_name + ) + run_sync(forwards_manager.create_migration_table()) + run_sync(forwards_manager.run_migrations(app_config=app_config)) + + def _get_migrations_folder_path(self) -> str: temp_directory_path = tempfile.gettempdir() migrations_folder_path = os.path.join( temp_directory_path, "piccolo_migrations" ) + return migrations_folder_path + + def _get_app_config(self) -> AppConfig: + return AppConfig( + app_name="test_app", + migrations_folder_path=self._get_migrations_folder_path(), + table_classes=[], + ) + + def _test_migrations( + self, + table_snapshots: list[list[type[Table]]], + test_function: Optional[Callable[[RowMeta], bool]] = None, + ): + """ + Writes a migration file to disk and runs it. + + :param table_snapshots: + A list of lists. Each sub list represents a snapshot of the table + state. Migrations will be created and run based on each snapshot. + :param test_function: + After the migrations are run, this function is called. It is passed + a ``RowMeta`` instance which can be used to check the column was + created correctly in the database. It should return ``True`` if the + test passes, otherwise ``False``. 
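+
+        A sketch of a typical call (``Varchar`` here is purely
+        illustrative)::
+
+            self._test_migrations(
+                table_snapshots=[
+                    [self.table(Varchar())],
+                    [self.table(Varchar(length=100))],
+                ],
+                test_function=lambda x: x.data_type == "character varying",
+            )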
+ + """ + app_config = self._get_app_config() + + migrations_folder_path = app_config.resolved_migrations_folder_path if os.path.exists(migrations_folder_path): shutil.rmtree(migrations_folder_path) _create_migrations_folder(migrations_folder_path) - app_config = AppConfig( - app_name="test_app", - migrations_folder_path=migrations_folder_path, - table_classes=[], - ) - - for table_class in table_classes: - app_config.table_classes = [table_class] + for table_snapshot in table_snapshots: + app_config.table_classes = table_snapshot meta = run_sync( - _create_new_migration(app_config=app_config, auto=True) + _create_new_migration( + app_config=app_config, auto=True, auto_input="y" + ) ) self.assertTrue(os.path.exists(meta.migration_path)) - self.run_migrations(app_config=app_config) + self._run_migrations(app_config=app_config) # It's kind of absurd sleeping for 1 microsecond, but it guarantees # the migration IDs will be unique, and just in case computers # and / or Python get insanely fast in the future :) time.sleep(1e-6) - # TODO - check the migrations ran correctly + if test_function: + column = table_snapshots[-1][-1]._meta.non_default_columns[0] + column_name = column._meta.db_column_name + schema = column._meta.table._meta.schema + tablename = column._meta.table._meta.tablename + row_meta = self.get_postgres_column_definition( + tablename=tablename, + column_name=column_name, + schema=schema or "public", + ) + self.assertTrue( + test_function(row_meta), + msg=f"Meta is incorrect: {row_meta}", + ) + + def _get_migration_managers(self): + app_config = self._get_app_config() + + return run_sync( + ForwardsMigrationManager( + app_name=app_config.app_name + ).get_migration_managers(app_config=app_config) + ) + + def _run_backwards(self, migration_id: str): + """ + After running :meth:`_test_migrations`, if you call `_run_backwards` + then the migrations can be reversed. + + :param migration_id: + Which migration to reverse to. Can be: + + * A migration ID. + * A number, like ``1``, then it will reverse the most recent + migration. + * ``'all'`` then all of the migrations will be reversed. + + """ + migrations_folder_path = self._get_migrations_folder_path() + + app_config = AppConfig( + app_name="test_app", + migrations_folder_path=migrations_folder_path, + table_classes=[], + ) + + backwards_manager = BackwardsMigrationManager( + app_name=app_config.app_name, + migration_id=migration_id, + auto_agree=True, + ) + run_sync( + backwards_manager.run_migrations_backwards(app_config=app_config) + ) + + +@engines_only("postgres", "cockroach") +class TestMigrations(MigrationTestCase): + def setUp(self): + pass + + def tearDown(self): + create_table_class("MyTable").alter().drop_table( + if_exists=True + ).run_sync() + Migration.alter().drop_table(if_exists=True).run_sync() ########################################################################### def table(self, column: Column): + """ + A utility for creating Piccolo tables with the given column. 
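+
+        ``self.table(Varchar())`` is roughly equivalent to::
+
+            class MyTable(Table):
+                my_column = Varchar()
+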
+ """ return create_table_class( class_name="MyTable", class_members={"my_column": column} ) + @engines_skip("cockroach") def test_varchar_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Varchar(), Varchar(length=100), @@ -136,13 +260,21 @@ def test_varchar_column(self): Varchar(index=True), Varchar(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "character varying", + x.is_nullable == "NO", + x.column_default + in ("''::character varying", "'':::STRING"), + ] + ), ) def test_text_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Text(), Text(default="hello world"), @@ -152,13 +284,25 @@ def test_text_column(self): Text(index=True), Text(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "text", + x.is_nullable == "NO", + x.column_default + in ( + "''", + "''::text", + "'':::STRING", + ), + ] + ), ) def test_integer_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Integer(), Integer(default=1), @@ -168,13 +312,64 @@ def test_integer_column(self): Integer(index=True), Integer(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type in ("integer", "bigint"), # Cockroach DB. + x.is_nullable == "NO", + x.column_default in ("0", "0:::INT8"), # Cockroach DB. + ] + ), + ) + + def test_real_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Real(), + Real(default=1.1), + Real(null=True), + Real(null=False), + Real(index=True), + Real(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "real", + x.is_nullable == "NO", + x.column_default in ("0.0", "0.0:::FLOAT8"), + ] + ), + ) + + def test_double_precision_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + DoublePrecision(), + DoublePrecision(default=1.1), + DoublePrecision(null=True), + DoublePrecision(null=False), + DoublePrecision(index=True), + DoublePrecision(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "double precision", + x.is_nullable == "NO", + x.column_default in ("0.0", "0.0:::FLOAT8"), + ] + ), ) def test_smallint_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ SmallInt(), SmallInt(default=1), @@ -184,13 +379,20 @@ def test_smallint_column(self): SmallInt(index=True), SmallInt(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "smallint", + x.is_nullable == "NO", + x.column_default in ("0", "0:::INT8"), # Cockroach DB. + ] + ), ) def test_bigint_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ BigInt(), BigInt(default=1), @@ -200,13 +402,20 @@ def test_bigint_column(self): BigInt(index=True), BigInt(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "bigint", + x.is_nullable == "NO", + x.column_default in ("0", "0:::INT8"), # Cockroach DB. 
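+                    # (``0:::INT8`` is Cockroach's type-annotated rendering
+                    # of the same default.)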
+ ] + ), ) def test_uuid_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ UUID(), UUID(default="ecf338cd-6da7-464c-b31e-4b2e3e12f0f0"), @@ -223,13 +432,20 @@ def test_uuid_column(self): UUID(index=True), UUID(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "uuid", + x.is_nullable == "NO", + x.column_default == "gen_random_uuid()", + ] + ), ) def test_timestamp_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Timestamp(), Timestamp( @@ -242,13 +458,27 @@ def test_timestamp_column(self): Timestamp(index=True), Timestamp(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "timestamp without time zone", + x.is_nullable == "NO", + x.column_default + in ( + "now()", + "CURRENT_TIMESTAMP", + "current_timestamp()::TIMESTAMP", + "current_timestamp():::TIMESTAMPTZ::TIMESTAMP", + ), + ] + ), ) + @engines_skip("cockroach") def test_time_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Time(), Time(default=datetime.time(hour=12, minute=0)), @@ -259,13 +489,21 @@ def test_time_column(self): Time(index=True), Time(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "time without time zone", + x.is_nullable == "NO", + x.column_default + in ("('now'::text)::time with time zone", "CURRENT_TIME"), + ] + ), ) def test_date_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Date(), Date(default=datetime.date(year=2021, month=1, day=1)), @@ -276,13 +514,25 @@ def test_date_column(self): Date(index=True), Date(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "date", + x.is_nullable == "NO", + x.column_default + in ( + "('now'::text)::date", + "CURRENT_DATE", + "current_date()", + ), + ] + ), ) def test_interval_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Interval(), Interval(default=datetime.timedelta(days=1)), @@ -292,13 +542,25 @@ def test_interval_column(self): Interval(index=True), Interval(index=False), ] - ] + ], + test_function=lambda x: all( + [ + x.data_type == "interval", + x.is_nullable == "NO", + x.column_default + in ( + "'00:00:00'", + "'00:00:00'::interval", + "'00:00:00':::INTERVAL", + ), + ] + ), ) def test_boolean_column(self): self._test_migrations( - table_classes=[ - self.table(column) + table_snapshots=[ + [self.table(column)] for column in [ Boolean(), Boolean(default=True), @@ -308,5 +570,960 @@ def test_boolean_column(self): Boolean(index=True), Boolean(index=False), ] + ], + test_function=lambda x: all( + [ + x.data_type == "boolean", + x.is_nullable == "NO", + x.column_default == "false", + ] + ), + ) + + @engines_skip("cockroach") + def test_numeric_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Numeric(), + Numeric(digits=(4, 2)), + Numeric(digits=None), + Numeric(default=decimal.Decimal("1.2")), + Numeric(default=numeric_default), + Numeric(null=True, default=None), + Numeric(null=False), + Numeric(index=True), + Numeric(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "numeric", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + @engines_skip("cockroach") 
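+    # ``Decimal`` is an alias of ``Numeric``, so the assertions below mirror
+    # ``test_numeric_column``.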
+ def test_decimal_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Decimal(), + Decimal(digits=(4, 2)), + Decimal(digits=None), + Decimal(default=decimal.Decimal("1.2")), + Decimal(default=numeric_default), + Decimal(null=True, default=None), + Decimal(null=False), + Decimal(index=True), + Decimal(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "numeric", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + @engines_skip("cockroach") + def test_array_column_integer(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/35730 "column my_column is of type int[] and thus is not indexable" + """ # noqa: E501 + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Array(base_column=Integer()), + Array(base_column=Integer(), default=[1, 2, 3]), + Array( + base_column=Integer(), default=array_default_integer + ), + Array(base_column=Integer(), null=True, default=None), + Array(base_column=Integer(), null=False), + Array(base_column=Integer(), index=True), + Array(base_column=Integer(), index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "ARRAY", + x.is_nullable == "NO", + x.column_default == "'{}'::integer[]", + ] + ), + ) + + @engines_skip("cockroach") + def test_array_column_varchar(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/35730 "column my_column is of type varchar[] and thus is not indexable" + """ # noqa: E501 + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Array(base_column=Varchar()), + Array(base_column=Varchar(), default=["a", "b", "c"]), + Array( + base_column=Varchar(), default=array_default_varchar + ), + Array(base_column=Varchar(), null=True, default=None), + Array(base_column=Varchar(), null=False), + Array(base_column=Varchar(), index=True), + Array(base_column=Varchar(), index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "ARRAY", + x.is_nullable == "NO", + x.column_default + in ("'{}'::character varying[]", "'':::STRING"), + ] + ), + ) + + def test_array_column_bigint(self): + """ + There was a bug with using an array of ``BigInt``: + + http://github.com/piccolo-orm/piccolo/issues/500/ + + It's because ``BigInt`` requires access to the parent table to + determine what the column type is. + + """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Array(base_column=BigInt()), + ] + ] + ) + + def test_array_base_column_change(self): + """ + There was a bug when trying to change the base column of an array: + + https://github.com/piccolo-orm/piccolo/issues/1076 + + It wasn't importing the base column, e.g. for ``Array(Text())`` it + wasn't importing ``Text``. + + """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Array(base_column=Varchar()), + Array(base_column=Text()), + ] + ] + ) + + ########################################################################### + + # We deliberately don't test setting JSON or JSONB columns as indexes, as + # we know it'll fail. + + @engines_skip("cockroach") + def test_json_column(self): + """ + Cockroach sees all json as jsonb, so we can skip this. 
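+
+        (Its ``data_type`` comes back as ``jsonb``, which
+        ``test_jsonb_column`` below covers.)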
+ """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + JSON(), + JSON(default=["a", "b", "c"]), + JSON(default={"name": "bob"}), + JSON(default='{"name": "Sally"}'), + JSON(null=True, default=None), + JSON(null=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "json", + x.is_nullable == "NO", + x.column_default == "'{}'::json", + ] + ), + ) + + def test_jsonb_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + JSONB(), + JSONB(default=["a", "b", "c"]), + JSONB(default={"name": "bob"}), + JSONB(default='{"name": "Sally"}'), + JSONB(null=True, default=None), + JSONB(null=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "jsonb", + x.is_nullable == "NO", + x.column_default + in ( + "'{}'", + "'{}'::jsonb", + "'{}':::JSONB", + ), + ] + ), + ) + + ########################################################################### + + def test_db_column_name(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Varchar(), + Varchar(db_column_name="custom_name"), + Varchar(), + Varchar(db_column_name="custom_name_2"), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "character varying", + x.is_nullable == "NO", + x.column_default + in ( + "''", + "''::character varying", + "'':::STRING", + ), + ] + ), + ) + + def test_db_column_name_initial(self): + """ + Make sure that if a new table is created which contains a column with + ``db_column_name`` specified, then the column has the correct name. + """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Varchar(db_column_name="custom_name"), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "character varying", + x.is_nullable == "NO", + x.column_default + in ( + "''", + "''::character varying", + "'':::STRING", + ), + ] + ), + ) + + ########################################################################### + + # Column type conversion + + def test_column_type_conversion_string(self): + """ + We can't manage all column type conversions, but should be able to + manage most simple ones (e.g. Varchar to Text). 
+        """
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Varchar(),
+                    Text(),
+                    Varchar(),
+                ]
+            ]
+        )
+
+    @engines_skip("cockroach")
+    def test_column_type_conversion_integer(self):
+        """
+        🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction"
+        """  # noqa: E501
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Integer(),
+                    BigInt(),
+                    SmallInt(),
+                    BigInt(),
+                    Integer(),
+                ]
+            ]
+        )
+
+    @engines_skip("cockroach")
+    def test_column_type_conversion_string_to_integer(self):
+        """
+        🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction"
+        """  # noqa: E501
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Varchar(default="1"),
+                    Integer(default=1),
+                    Varchar(default="1"),
+                ]
+            ]
+        )
+
+    @engines_skip("cockroach")
+    def test_column_type_conversion_float_decimal(self):
+        """
+        🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction"
+        """  # noqa: E501
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Real(default=1.0),
+                    DoublePrecision(default=1.0),
+                    Real(default=1.0),
+                    Numeric(),
+                    Real(default=1.0),
+                ]
+            ]
+        )
+
+    def test_column_type_conversion_integer_float(self):
+        """
+        Make sure conversion between ``Integer`` and ``Real`` works - related
+        to this bug:
+
+        https://github.com/piccolo-orm/piccolo/issues/1071
+
+        """
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Real(default=1.0),
+                    Integer(default=1),
+                    Real(default=1.0),
+                ]
+            ]
+        )
+
+    def test_column_type_conversion_json(self):
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    JSON(),
+                    JSONB(),
+                    JSON(),
+                ]
             ]
         )
+
+    def test_column_type_conversion_timestamp(self):
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Timestamp(),
+                    Timestamptz(),
+                    Timestamp(),
+                ]
+            ]
+        )
+
+    @patch("piccolo.apps.migrations.auto.migration_manager.colored_warning")
+    def test_column_type_conversion_serial(self, colored_warning: MagicMock):
+        """
+        This isn't possible, as neither SERIAL nor BIGSERIAL is an actual
+        type. They're just shortcuts. Make sure the migration doesn't crash
+        - it should just output a warning.
+        """
+        self._test_migrations(
+            table_snapshots=[
+                [self.table(column)]
+                for column in [
+                    Serial(),
+                    BigSerial(),
+                ]
+            ]
+        )
+
+        colored_warning.assert_called_once_with(
+            "Unable to migrate Serial to BigSerial and vice versa. This must "
+            "be done manually."
+        )
+
+
+###############################################################################
+
+
+class Band(Table):
+    name = Varchar()
+    genres = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+class Genre(Table):
+    name = Varchar()
+    bands = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+class GenreToBand(Table):
+    band = ForeignKey(Band)
+    genre = ForeignKey(Genre)
+
+
+@engines_only("postgres", "cockroach")
+class TestM2MMigrations(MigrationTestCase):
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        drop_db_tables_sync(Migration, Band, Genre, GenreToBand)
+
+    def test_m2m(self):
+        """
+        Make sure M2M relations can be created.
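+
+        A sketch of how the relation can be queried once it exists (not
+        asserted here; this test only checks that the tables get created)::
+
+            >>> await Band.select(Band.name, Band.genres(Genre.name))
+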
+ """ + + self._test_migrations( + table_snapshots=[[Band, Genre, GenreToBand]], + ) + + for table_class in [Band, Genre, GenreToBand]: + self.assertTrue(table_class.table_exists().run_sync()) + + +############################################################################### + + +@engines_only("postgres", "cockroach") +class TestForeignKeys(MigrationTestCase): + def setUp(self): + class TableA(Table): + pass + + class TableB(Table): + fk = ForeignKey(TableA) + + class TableC(Table): + fk = ForeignKey(TableB) + + class TableD(Table): + fk = ForeignKey(TableC) + + class TableE(Table): + fk = ForeignKey(TableD) + + self.table_classes = [TableA, TableB, TableC, TableD, TableE] + + def tearDown(self): + drop_db_tables_sync(Migration, *self.table_classes) + + def test_foreign_keys(self): + """ + Make sure that if we try creating tables with lots of foreign keys + to each other it runs successfully. + + https://github.com/piccolo-orm/piccolo/issues/616 + + """ + # We'll shuffle them, to make it a more thorough test. + table_classes = random.sample( + self.table_classes, len(self.table_classes) + ) + + self._test_migrations(table_snapshots=[table_classes]) + for table_class in table_classes: + self.assertTrue(table_class.table_exists().run_sync()) + + +@engines_only("postgres", "cockroach") +class TestTargetColumn(MigrationTestCase): + def setUp(self): + class TableA(Table): + name = Varchar(unique=True) + + class TableB(Table): + table_a = ForeignKey(TableA, target_column=TableA.name) + + self.table_classes = [TableA, TableB] + + def tearDown(self): + drop_db_tables_sync(Migration, *self.table_classes) + + def test_target_column(self): + """ + Make sure migrations still work when a foreign key references a column + other than the primary key. + """ + self._test_migrations( + table_snapshots=[self.table_classes], + ) + + for table_class in self.table_classes: + self.assertTrue(table_class.table_exists().run_sync()) + + # Make sure the constraint was created correctly. + response = self.run_sync( + """ + SELECT EXISTS( + SELECT 1 + FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE CCU + JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC ON + CCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME + WHERE CONSTRAINT_TYPE = 'FOREIGN KEY' + AND TC.TABLE_NAME = 'table_b' + AND CCU.TABLE_NAME = 'table_a' + AND CCU.COLUMN_NAME = 'name' + ) + """ + ) + self.assertTrue(response[0]["exists"]) + + +@engines_only("postgres", "cockroach") +class TestForeignKeySelf(MigrationTestCase): + def setUp(self) -> None: + class TableA(Table): + id = UUID(primary_key=True) + table_a: ForeignKey[TableA] = ForeignKey("self") + + self.table_classes: list[type[Table]] = [TableA] + + def tearDown(self): + drop_db_tables_sync(Migration, *self.table_classes) + + def test_create_table(self): + """ + Make sure migrations still work when: + + * Creating a new table with a foreign key which references itself. + * The table has a custom primary key type (e.g. UUID). 
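+
+        The table under test (from ``setUp``) boils down to::
+
+            class TableA(Table):
+                id = UUID(primary_key=True)
+                table_a = ForeignKey("self")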
+
+        """
+        self._test_migrations(
+            table_snapshots=[self.table_classes],
+            test_function=lambda x: x.data_type == "uuid",
+        )
+
+        for table_class in self.table_classes:
+            self.assertTrue(table_class.table_exists().run_sync())
+
+
+@engines_only("postgres", "cockroach")
+class TestAddForeignKeySelf(MigrationTestCase):
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        drop_db_tables_sync(create_table_class("MyTable"), Migration)
+
+    @patch("piccolo.conf.apps.Finder.get_app_config")
+    def test_add_column(self, get_app_config):
+        """
+        Make sure migrations still work when:
+
+        * A foreign key is added to an existing table.
+        * The foreign key references its own table.
+        * The table has a custom primary key (e.g. UUID).
+
+        """
+        get_app_config.return_value = self._get_app_config()
+
+        self._test_migrations(
+            table_snapshots=[
+                [
+                    create_table_class(
+                        class_name="MyTable",
+                        class_members={"id": UUID(primary_key=True)},
+                    )
+                ],
+                [
+                    create_table_class(
+                        class_name="MyTable",
+                        class_members={
+                            "id": UUID(primary_key=True),
+                            "fk": ForeignKey("self"),
+                        },
+                    )
+                ],
+            ],
+            test_function=lambda x: x.data_type == "uuid",
+        )
+
+
+###############################################################################
+# Testing migrations which involve schemas.
+
+
+@engines_only("postgres", "cockroach")
+class TestSchemas(MigrationTestCase):
+    new_schema = "schema_1"
+
+    def setUp(self) -> None:
+        self.schema_manager = SchemaManager()
+        self.manager_1 = create_table_class(class_name="Manager")
+        self.manager_2 = create_table_class(
+            class_name="Manager", class_kwargs={"schema": self.new_schema}
+        )
+
+    def tearDown(self) -> None:
+        self.schema_manager.drop_schema(
+            self.new_schema, if_exists=True, cascade=True
+        ).run_sync()
+
+        Migration.alter().drop_table(if_exists=True).run_sync()
+
+        self.manager_1.alter().drop_table(if_exists=True).run_sync()
+
+    def test_create_table_in_schema(self):
+        """
+        Make sure we can create a new table in a schema.
+        """
+        self._test_migrations(table_snapshots=[[self.manager_2]])
+
+        # The schema should automatically be created.
+        self.assertIn(
+            self.new_schema,
+            self.schema_manager.list_schemas().run_sync(),
+        )
+
+        # Make sure that the table is in the new schema.
+        self.assertListEqual(
+            self.schema_manager.list_tables(
+                schema_name=self.new_schema
+            ).run_sync(),
+            ["manager"],
+        )
+
+        # Roll it backwards to make sure the table no longer exists.
+        self._run_backwards(migration_id="1")
+
+        # Make sure that the table is no longer in the new schema.
+        self.assertNotIn(
+            "manager",
+            self.schema_manager.list_tables(
+                schema_name=self.new_schema
+            ).run_sync(),
+        )
+
+    def test_move_table_from_public_schema(self):
+        """
+        Make sure the auto migrations can move a table from the public schema
+        to a different schema.
+        """
+        self._test_migrations(
+            table_snapshots=[
+                [self.manager_1],
+                [self.manager_2],
+            ],
+        )
+
+        # The schema should automatically be created.
+        self.assertIn(
+            self.new_schema,
+            self.schema_manager.list_schemas().run_sync(),
+        )
+
+        # Make sure that the table is in the new schema.
+        self.assertListEqual(
+            self.schema_manager.list_tables(
+                schema_name=self.new_schema
+            ).run_sync(),
+            ["manager"],
+        )
+
+        #######################################################################
+
+        # Reverse the last migration, which should move the table back to the
+        # public schema.
+ self._run_backwards(migration_id="1") + + self.assertIn( + "manager", + self.schema_manager.list_tables(schema_name="public").run_sync(), + ) + + # We don't delete the schema we created as it's risky, just in case + # other tables etc were manually added to it. + self.assertIn( + self.new_schema, + self.schema_manager.list_schemas().run_sync(), + ) + + def test_move_table_to_public_schema(self): + """ + Make sure the auto migrations can move a table from a schema to the + public schema. + """ + self._test_migrations( + table_snapshots=[ + [self.manager_2], + [self.manager_1], + ], + ) + + # Make sure that the table is in the public schema. + self.assertIn( + "manager", + self.schema_manager.list_tables(schema_name="public").run_sync(), + ) + + ####################################################################### + + # Reverse the last migration, which should move the table back to the + # non-public schema. + self._run_backwards(migration_id="1") + + self.assertIn( + "manager", + self.schema_manager.list_tables( + schema_name=self.new_schema + ).run_sync(), + ) + + def test_altering_table_in_schema(self): + """ + Make sure tables in schemas can be altered. + + https://github.com/piccolo-orm/piccolo/issues/883 + + """ + self._test_migrations( + table_snapshots=[ + # Create a table with a single column + [ + create_table_class( + class_name="Manager", + class_kwargs={"schema": self.new_schema}, + class_members={"first_name": Varchar()}, + ) + ], + # Rename the column + [ + create_table_class( + class_name="Manager", + class_kwargs={"schema": self.new_schema}, + class_members={"name": Varchar()}, + ) + ], + # Add a column + [ + create_table_class( + class_name="Manager", + class_kwargs={"schema": self.new_schema}, + class_members={ + "name": Varchar(), + "age": Integer(), + }, + ) + ], + # Remove a column + [ + create_table_class( + class_name="Manager", + class_kwargs={"schema": self.new_schema}, + class_members={ + "name": Varchar(), + }, + ) + ], + # Alter a column + [ + create_table_class( + class_name="Manager", + class_kwargs={"schema": self.new_schema}, + class_members={ + "name": Varchar(length=512), + }, + ) + ], + ], + test_function=lambda x: all( + [ + x.column_name == "name", + x.data_type == "character varying", + x.character_maximum_length == 512, + ] + ), + ) + + +@engines_only("postgres", "cockroach") +class TestSameTableName(MigrationTestCase): + """ + Tables with the same name are allowed in multiple schemas. + """ + + new_schema = "schema_1" + tablename = "manager" + + def setUp(self) -> None: + self.schema_manager = SchemaManager() + + self.manager_1 = create_table_class( + class_name="Manager1", class_kwargs={"tablename": self.tablename} + ) + + self.manager_2 = create_table_class( + class_name="Manager2", + class_kwargs={"tablename": self.tablename, "schema": "schema_1"}, + ) + + def tearDown(self) -> None: + self.schema_manager.drop_schema( + self.new_schema, if_exists=True, cascade=True + ).run_sync() + + self.manager_1.alter().drop_table(if_exists=True).run_sync() + + Migration.alter().drop_table(if_exists=True).run_sync() + + def test_schemas(self): + """ + Make sure we can create a table with the same name in multiple schemas. 
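+
+        i.e. after the migrations run, both ``public.manager`` and
+        ``schema_1.manager`` should exist.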
+ """ + + self._test_migrations( + table_snapshots=[ + [self.manager_1], + [self.manager_1, self.manager_2], + ], + ) + + # Make sure that both tables exist (in the correct schemas): + self.assertIn( + "manager", + self.schema_manager.list_tables(schema_name="public").run_sync(), + ) + self.assertIn( + "manager", + self.schema_manager.list_tables( + schema_name=self.new_schema + ).run_sync(), + ) + + +@engines_only("postgres", "cockroach") +class TestForeignKeyWithSchema(MigrationTestCase): + """ + Make sure that migrations with foreign keys involving schemas work + correctly. + """ + + schema = "schema_1" + schema_manager = SchemaManager() + + def setUp(self) -> None: + self.manager = create_table_class( + class_name="Manager", class_kwargs={"schema": self.schema} + ) + + self.band = create_table_class( + class_name="Band", + class_kwargs={"schema": self.schema}, + class_members={"manager": ForeignKey(self.manager)}, + ) + + def tearDown(self) -> None: + self.schema_manager.drop_schema( + self.schema, if_exists=True, cascade=True + ).run_sync() + + Migration.alter().drop_table(if_exists=True).run_sync() + + def test_foreign_key(self): + self._test_migrations( + table_snapshots=[ + [self.manager, self.band], + ], + ) + + tables_in_schema = self.schema_manager.list_tables( + schema_name=self.schema + ).run_sync() + + # Make sure that both tables exist (in the correct schemas): + for tablename in ("manager", "band"): + self.assertIn(tablename, tables_in_schema) + + +############################################################################### + + +@engines_only("postgres", "cockroach") +class TestRenameTable(MigrationTestCase): + """ + Make sure that tables can be renamed. + """ + + schema_manager = SchemaManager() + manager = create_table_class( + class_name="Manager", class_members={"name": Varchar()} + ) + manager_1 = create_table_class( + class_name="Manager", + class_kwargs={"tablename": "manager_1"}, + class_members={"name": Varchar()}, + ) + + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + drop_db_tables_sync(self.manager, self.manager_1, Migration) + + def test_rename_table(self): + self._test_migrations( + table_snapshots=[ + [self.manager], + [self.manager_1], + ], + ) + + tables = self.schema_manager.list_tables( + schema_name="public" + ).run_sync() + + self.assertIn("manager_1", tables) + self.assertNotIn("manager", tables) + + # Make sure the table was renamed, and not dropped and recreated. + migration_managers = self._get_migration_managers() + + self.assertListEqual( + migration_managers[-1].rename_tables, + [ + RenameTable( + old_class_name="Manager", + old_tablename="manager", + new_class_name="Manager", + new_tablename="manager_1", + ) + ], + ) diff --git a/tests/apps/migrations/auto/test_diffable_table.py b/tests/apps/migrations/auto/test_diffable_table.py index 327ed6b59..cacd6d612 100644 --- a/tests/apps/migrations/auto/test_diffable_table.py +++ b/tests/apps/migrations/auto/test_diffable_table.py @@ -4,16 +4,88 @@ DiffableTable, compare_dicts, ) -from piccolo.columns import Varchar +from piccolo.columns import OnDelete, Varchar class TestCompareDicts(TestCase): - def test_compare_dicts(self): + def test_simple(self): + """ + Make sure that simple values are compared properly. + """ dict_1 = {"a": 1, "b": 2} dict_2 = {"a": 1, "b": 3} response = compare_dicts(dict_1, dict_2) self.assertEqual(response, {"b": 2}) + def test_missing_keys(self): + """ + Make sure that if one dictionary has keys that the other doesn't, + it works as expected. 
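+
+        (Keys only present in the first dict are returned; keys only present
+        in the second are ignored.)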
+ """ + dict_1 = {"a": 1} + dict_2 = {"b": 2, "c": 3} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"a": 1}) + + def test_list_value(self): + """ + Make sure list values work correctly. + """ + dict_1 = {"a": 1, "b": [1]} + dict_2 = {"a": 1, "b": [2]} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"b": [1]}) + + def test_dict_value(self): + """ + Make sure dictionary values work correctly. + """ + dict_1 = {"a": 1, "b": {"x": 1}} + dict_2 = {"a": 1, "b": {"x": 1}} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {}) + + dict_1 = {"a": 1, "b": {"x": 1}} + dict_2 = {"a": 1, "b": {"x": 2}} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"b": {"x": 1}}) + + def test_none_values(self): + """ + Make sure there are no edge cases when using None values. + """ + dict_1 = {"a": None, "b": 1} + dict_2 = {"a": None} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"b": 1}) + + def test_enum_values(self): + """ + Make sure Enum values can be compared correctly. + """ + dict_1 = {"a": OnDelete.cascade} + dict_2 = {"a": OnDelete.cascade} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {}) + + dict_1 = {"a": OnDelete.set_default} + dict_2 = {"a": OnDelete.cascade} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"a": OnDelete.set_default}) + + def test_numeric_values(self): + """ + Make sure that if we have two numbers which are equal, but different + types, then they are identified as being different. + + https://github.com/piccolo-orm/piccolo/issues/1071 + + """ + dict_1 = {"a": 1} + dict_2 = {"a": 1.0} + response = compare_dicts(dict_1, dict_2) + self.assertEqual(response, {"a": 1}) + class TestDiffableTable(TestCase): def test_subtract(self): diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 965473613..0952e1895 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -1,21 +1,122 @@ import asyncio +import random +from io import StringIO +from typing import Optional +from unittest import IsolatedAsyncioTestCase, TestCase from unittest.mock import MagicMock, patch -from piccolo.apps.migrations.auto import MigrationManager +from piccolo.apps.migrations.auto.migration_manager import MigrationManager from piccolo.apps.migrations.commands.base import BaseMigrationManager from piccolo.columns import Text, Varchar from piccolo.columns.base import OnDelete, OnUpdate from piccolo.columns.column_types import ForeignKey from piccolo.conf.apps import AppConfig +from piccolo.engine import engine_finder +from piccolo.query.constraints import get_fk_constraint_rules +from piccolo.table import Table, sort_table_classes from piccolo.utils.lazy_loader import LazyLoader -from tests.base import DBTestCase, postgres_only, set_mock_return_value -from tests.example_app.tables import Manager +from piccolo.utils.sync import run_sync +from tests.base import AsyncMock, DBTestCase, engine_is, engines_only +from tests.example_apps.music.tables import Band, Concert, Manager, Venue asyncpg = LazyLoader("asyncpg", globals(), "asyncpg") +class TestSortTableClasses(TestCase): + def test_sort_table_classes(self): + """ + Make sure simple use cases work correctly. 
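+
+        e.g. ``Band`` has a foreign key to ``Manager``, so ``Manager`` must
+        come first::
+
+            >>> sort_table_classes([Band, Manager])
+            [Manager, Band]
+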
+ """ + self.assertListEqual( + sort_table_classes([Manager, Band]), [Manager, Band] + ) + self.assertListEqual( + sort_table_classes([Band, Manager]), [Manager, Band] + ) + + sorted_tables = sort_table_classes([Manager, Venue, Concert, Band]) + self.assertTrue( + sorted_tables.index(Manager) < sorted_tables.index(Band) + ) + self.assertTrue( + sorted_tables.index(Venue) < sorted_tables.index(Concert) + ) + self.assertTrue( + sorted_tables.index(Band) < sorted_tables.index(Concert) + ) + + def test_sort_unrelated_tables(self): + """ + Make sure there are no weird edge cases with tables with no foreign + key relationships with each other. + """ + + class TableA(Table): + pass + + class TableB(Table): + pass + + self.assertListEqual( + sort_table_classes([TableA, TableB]), [TableA, TableB] + ) + + def test_single_table(self): + """ + Make sure that sorting a list with only a single table in it still + works. + """ + self.assertListEqual(sort_table_classes([Band]), [Band]) + + def test_recursive_table(self): + """ + Make sure that a table with a foreign key to itself sorts without + issues. + """ + + class TableA(Table): + table_a = ForeignKey("self") + + class TableB(Table): + table_a = ForeignKey(TableA) + + self.assertListEqual( + sort_table_classes([TableA, TableB]), [TableA, TableB] + ) + + def test_long_chain(self): + """ + Make sure sorting works when there are a lot of tables with foreign + keys to each other. + + https://github.com/piccolo-orm/piccolo/issues/616 + + """ + + class TableA(Table): + pass + + class TableB(Table): + fk = ForeignKey(TableA) + + class TableC(Table): + fk = ForeignKey(TableB) + + class TableD(Table): + fk = ForeignKey(TableC) + + class TableE(Table): + fk = ForeignKey(TableD) + + tables = [TableA, TableB, TableC, TableD, TableE] + + shuffled_tables = random.sample(tables, len(tables)) + + self.assertListEqual(sort_table_classes(shuffled_tables), tables) + + class TestMigrationManager(DBTestCase): - @postgres_only + @engines_only("postgres", "cockroach") def test_rename_column(self): """ Test running a MigrationManager which contains a column rename @@ -37,7 +138,19 @@ def test_rename_column(self): self.assertTrue("name" not in response[0].keys()) # Reverse - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM band;") + self.assertTrue("title" not in response[0].keys()) + self.assertTrue("name" in response[0].keys()) + + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + self.assertEqual( + fake_out.getvalue(), + """ - [preview forwards]... 
\n ALTER TABLE "band" RENAME COLUMN "name" TO "title";\n""", # noqa: E501 + ) response = self.run_sync("SELECT * FROM band;") self.assertTrue("title" not in response[0].keys()) self.assertTrue("name" in response[0].keys()) @@ -50,19 +163,25 @@ def test_raw_function(self): class HasRun(Exception): pass + class HasRunBackwards(Exception): + pass + def run(): raise HasRun("I was run!") + def run_back(): + raise HasRunBackwards("I was run backwards!") + manager = MigrationManager() manager.add_raw(run) - manager.add_raw_backwards(run) + manager.add_raw_backwards(run_back) with self.assertRaises(HasRun): asyncio.run(manager.run()) # Reverse - with self.assertRaises(HasRun): - asyncio.run(manager.run_backwards()) + with self.assertRaises(HasRunBackwards): + asyncio.run(manager.run(backwards=True)) def test_raw_coroutine(self): """ @@ -72,21 +191,27 @@ def test_raw_coroutine(self): class HasRun(Exception): pass + class HasRunBackwards(Exception): + pass + async def run(): raise HasRun("I was run!") + async def run_back(): + raise HasRunBackwards("I was run backwards!") + manager = MigrationManager() manager.add_raw(run) - manager.add_raw_backwards(run) + manager.add_raw_backwards(run_back) with self.assertRaises(HasRun): asyncio.run(manager.run()) # Reverse - with self.assertRaises(HasRun): - asyncio.run(manager.run_backwards()) + with self.assertRaises(HasRunBackwards): + asyncio.run(manager.run(backwards=True)) - @postgres_only + @engines_only("postgres", "cockroach") @patch.object(BaseMigrationManager, "get_app_config") def test_add_table(self, get_app_config: MagicMock): """ @@ -104,22 +229,49 @@ def test_add_table(self, get_app_config: MagicMock): ) asyncio.run(manager.run()) - self.run_sync("INSERT INTO musician VALUES (default, 'Bob Jones');") - response = self.run_sync("SELECT * FROM musician;") + if engine_is("postgres"): + self.run_sync( + "INSERT INTO musician VALUES (default, 'Bob Jones');" + ) + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual(response, [{"id": 1, "name": "Bob Jones"}]) - self.assertEqual(response, [{"id": 1, "name": "Bob Jones"}]) + if engine_is("cockroach"): + id = self.run_sync( + "INSERT INTO musician VALUES (default, 'Bob Jones') RETURNING id;" # noqa: E501 + ) + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual( + response, [{"id": id[0]["id"], "name": "Bob Jones"}] + ) # Reverse - get_app_config.return_value = AppConfig( app_name="music", migrations_folder_path="" ) - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertEqual(self.table_exists("musician"), False) self.run_sync("DROP TABLE IF EXISTS musician;") - @postgres_only - def test_add_column(self): + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + + if engine_is("postgres"): + self.assertEqual( + fake_out.getvalue(), + """ - [preview forwards]... \n CREATE TABLE "musician" ("id" SERIAL PRIMARY KEY NOT NULL, "name" VARCHAR(255) NOT NULL DEFAULT '');\n""", # noqa: E501 + ) + if engine_is("cockroach"): + self.assertEqual( + fake_out.getvalue(), + """ - [preview forwards]... \n CREATE TABLE "musician" ("id" INTEGER PRIMARY KEY NOT NULL DEFAULT unique_rowid(), "name" VARCHAR(255) NOT NULL DEFAULT '');\n""", # noqa: E501 + ) + self.assertEqual(self.table_exists("musician"), False) + + @engines_only("postgres", "cockroach") + def test_add_column(self) -> None: """ Test adding a column to a MigrationManager. 
""" @@ -134,29 +286,59 @@ def test_add_column(self): "length": 100, "default": "", "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": True, "index": False, }, ) asyncio.run(manager.run()) - self.run_sync( - "INSERT INTO manager VALUES (default, 'Dave', 'dave@me.com');" - ) + if engine_is("postgres"): + self.run_sync( + "INSERT INTO \"manager\" VALUES (default, 'Dave', 'dave@me.com');" # noqa: E501 + ) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual( + response, [{"id": 1, "name": "Dave", "email": "dave@me.com"}] + ) - response = self.run_sync("SELECT * FROM manager;") - self.assertEqual( - response, [{"id": 1, "name": "Dave", "email": "dave@me.com"}] - ) + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + + row_id: Optional[int] = None + if engine_is("cockroach"): + row_id = self.run_sync( + "INSERT INTO manager VALUES (default, 'Dave', 'dave@me.com') RETURNING id;" # noqa: E501 + )[0]["id"] + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual( + response, + [{"id": row_id, "name": "Dave", "email": "dave@me.com"}], + ) + + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) + + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + self.assertEqual( + fake_out.getvalue(), + """ - [preview forwards]... \n ALTER TABLE "manager" ADD COLUMN "email" VARCHAR(100) UNIQUE DEFAULT '';\n""", # noqa: E501 + ) - # Reverse - asyncio.run(manager.run_backwards()) response = self.run_sync("SELECT * FROM manager;") - self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + if engine_is("postgres"): + self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + if engine_is("cockroach"): + self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) - @postgres_only + @engines_only("postgres", "cockroach") def test_add_column_with_index(self): """ Test adding a column with an index to a MigrationManager. @@ -172,8 +354,7 @@ def test_add_column_with_index(self): "length": 100, "default": "", "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": True, "index": True, }, @@ -184,10 +365,23 @@ def test_add_column_with_index(self): self.assertTrue(index_name in Manager.indexes().run_sync()) # Reverse - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) + self.assertTrue(index_name not in Manager.indexes().run_sync()) + + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + self.assertEqual( + fake_out.getvalue(), + ( + """ - [preview forwards]... 
\n ALTER TABLE "manager" ADD COLUMN "email" VARCHAR(100) UNIQUE DEFAULT '';\n""" # noqa: E501 + """\n CREATE INDEX manager_email ON "manager" USING btree ("email");\n""" # noqa: E501 + ), + ) self.assertTrue(index_name not in Manager.indexes().run_sync()) - @postgres_only + @engines_only("postgres") def test_add_foreign_key_self_column(self): """ Test adding a ForeignKey column to a MigrationManager, with a @@ -206,8 +400,7 @@ def test_add_foreign_key_self_column(self): "on_update": OnUpdate.cascade, "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -227,14 +420,68 @@ def test_add_foreign_key_self_column(self): ) # Reverse - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) response = self.run_sync("SELECT * FROM manager;") self.assertEqual( response, [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Dave"}], ) - @postgres_only + @engines_only("cockroach") + def test_add_foreign_key_self_column_alt(self): + """ + Test adding a ForeignKey column to a MigrationManager, with a + references argument of 'self'. + """ + manager = MigrationManager() + manager.add_column( + table_class_name="Manager", + tablename="manager", + column_name="advisor", + column_class=ForeignKey, + column_class_name="ForeignKey", + params={ + "references": "self", + "on_delete": OnDelete.cascade, + "on_update": OnUpdate.cascade, + "default": None, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + }, + ) + asyncio.run(manager.run()) + + id = self.run_sync( + "INSERT INTO manager VALUES (default, 'Alice', null) RETURNING id;" + ) + id2 = Manager.raw( + "INSERT INTO manager VALUES (default, 'Dave', {}) RETURNING id;", + id[0]["id"], + ).run_sync() + + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual( + response, + [ + {"id": id[0]["id"], "name": "Alice", "advisor": None}, + {"id": id2[0]["id"], "name": "Dave", "advisor": id[0]["id"]}, + ], + ) + + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual( + response, + [ + {"id": id[0]["id"], "name": "Alice"}, + {"id": id2[0]["id"], "name": "Dave"}, + ], + ) + + @engines_only("postgres", "cockroach") def test_add_non_nullable_column(self): """ Test adding a non nullable column to a MigrationManager. 
@@ -254,16 +501,17 @@ def test_add_non_nullable_column(self): "length": 100, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": True, "index": False, }, ) asyncio.run(manager.run()) - @postgres_only - @patch.object(BaseMigrationManager, "get_migration_managers") + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) @patch.object(BaseMigrationManager, "get_app_config") def test_drop_column( self, get_app_config: MagicMock, get_migration_managers: MagicMock @@ -281,9 +529,20 @@ def test_drop_column( ) asyncio.run(manager_1.run()) - self.run_sync("INSERT INTO musician VALUES (default, 'Dave');") - response = self.run_sync("SELECT * FROM musician;") - self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + if engine_is("postgres"): + self.run_sync("INSERT INTO musician VALUES (default, 'Dave');") + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + + id = 0 + if engine_is("cockroach"): + id = self.run_sync( + "INSERT INTO musician VALUES (default, 'Dave') RETURNING id;" + ) + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual( + response, [{"id": id[0]["id"], "name": "Dave"}] # type: ignore + ) manager_2 = MigrationManager() manager_2.drop_column( @@ -293,18 +552,29 @@ def test_drop_column( ) asyncio.run(manager_2.run()) - response = self.run_sync("SELECT * FROM musician;") - self.assertEqual(response, [{"id": 1}]) + if engine_is("postgres"): + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual(response, [{"id": 1}]) + + if engine_is("cockroach"): + response = self.run_sync("SELECT * FROM musician;") + self.assertEqual(response, [{"id": id[0]["id"]}]) # type: ignore # Reverse - set_mock_return_value(get_migration_managers, [manager_1]) + get_migration_managers.return_value = [manager_1] app_config = AppConfig(app_name="music", migrations_folder_path="") get_app_config.return_value = app_config - asyncio.run(manager_2.run_backwards()) + asyncio.run(manager_2.run(backwards=True)) response = self.run_sync("SELECT * FROM musician;") - self.assertEqual(response, [{"id": 1, "name": ""}]) + if engine_is("postgres"): + self.assertEqual(response, [{"id": 1, "name": ""}]) - @postgres_only + if engine_is("cockroach"): + self.assertEqual( + response, [{"id": id[0]["id"], "name": ""}] # type: ignore + ) + + @engines_only("postgres", "cockroach") def test_rename_table(self): """ Test renaming a table with MigrationManager. 
@@ -320,21 +590,85 @@ def test_rename_table(self): asyncio.run(manager.run()) - self.run_sync("INSERT INTO director VALUES (default, 'Dave');") + if engine_is("postgres"): + self.run_sync("INSERT INTO director VALUES (default, 'Dave');") + + response = self.run_sync("SELECT * FROM director;") + self.assertEqual(response, [{"id": 1, "name": "Dave"}]) - response = self.run_sync("SELECT * FROM director;") - self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + + if engine_is("cockroach"): + id = 0 + id = self.run_sync( + "INSERT INTO director VALUES (default, 'Dave') RETURNING id;" + ) + + response = self.run_sync("SELECT * FROM director;") + self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}]) + + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM manager;") + self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}]) + + @engines_only("postgres", "cockroach") + def test_alter_fk_on_delete_on_update(self): + """ + Test altering OnDelete and OnUpdate with MigrationManager. + """ + # before performing migrations - OnDelete.no_action + self.assertEqual( + run_sync(get_fk_constraint_rules(column=Band.manager)).on_delete, + OnDelete.no_action, + ) + + manager = MigrationManager(app_name="music") + manager.alter_column( + table_class_name="Band", + tablename="band", + column_name="manager", + db_column_name="manager", + params={ + "on_delete": OnDelete.set_null, + "on_update": OnUpdate.set_null, + }, + old_params={ + "on_delete": OnDelete.no_action, + "on_update": OnUpdate.no_action, + }, + column_class=ForeignKey, + old_column_class=ForeignKey, + schema=None, + ) + + asyncio.run(manager.run()) + + # after performing migrations - OnDelete.set_null + self.assertEqual( + run_sync(get_fk_constraint_rules(column=Band.manager)).on_delete, + OnDelete.set_null, + ) # Reverse - asyncio.run(manager.run_backwards()) - response = self.run_sync("SELECT * FROM manager;") - self.assertEqual(response, [{"id": 1, "name": "Dave"}]) + asyncio.run(manager.run(backwards=True)) - @postgres_only + # after performing reverse migrations we have + # OnDelete.no_action again + self.assertEqual( + run_sync(get_fk_constraint_rules(column=Band.manager)).on_delete, + OnDelete.no_action, + ) + + @engines_only("postgres") def test_alter_column_unique(self): """ Test altering a column uniqueness with MigrationManager. - """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/42840 "unimplemented: cannot drop UNIQUE constraint "manager_name_key" using ALTER TABLE DROP CONSTRAINT, use DROP INDEX CASCADE instead" + """ # noqa: E501 manager = MigrationManager() manager.alter_column( @@ -354,14 +688,14 @@ def test_alter_column_unique(self): ) # Reverse - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.run_sync( "INSERT INTO manager VALUES (default, 'Dave'), (default, 'Dave');" ) response = self.run_sync("SELECT name FROM manager;") self.assertEqual(response, [{"name": "Dave"}, {"name": "Dave"}]) - @postgres_only + @engines_only("postgres", "cockroach") def test_alter_column_set_null(self): """ Test altering whether a column is nullable with MigrationManager. 
@@ -384,7 +718,7 @@ def test_alter_column_set_null(self): ) # Reverse - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertFalse( self.get_postgres_is_nullable( tablename="manager", column_name="name" @@ -409,11 +743,12 @@ def _get_column_default(self, tablename="manager", column_name="name"): f"AND column_name = '{column_name}';" ) - @postgres_only + @engines_only("postgres") def test_alter_column_digits(self): """ Test altering a column digits with MigrationManager. - """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" + """ # noqa: E501 manager = MigrationManager() manager.alter_column( @@ -430,13 +765,13 @@ def test_alter_column_digits(self): [{"numeric_precision": 6, "numeric_scale": 2}], ) - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertEqual( self._get_column_precision_and_scale(), [{"numeric_precision": 5, "numeric_scale": 2}], ) - @postgres_only + @engines_only("postgres") def test_alter_column_set_default(self): """ Test altering a column default with MigrationManager. @@ -457,13 +792,40 @@ def test_alter_column_set_default(self): [{"column_default": "'Unknown'::character varying"}], ) - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertEqual( self._get_column_default(), [{"column_default": "''::character varying"}], ) - @postgres_only + @engines_only("cockroach") + def test_alter_column_set_default_alt(self): + """ + Test altering a column default with MigrationManager. + """ + manager = MigrationManager() + + manager.alter_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + params={"default": "Unknown"}, + old_params={"default": ""}, + ) + + asyncio.run(manager.run()) + self.assertIn( + self._get_column_default()[0]["column_default"], + ["'Unknown'", "'Unknown':::STRING"], + ) + + asyncio.run(manager.run(backwards=True)) + self.assertIn( + self._get_column_default()[0]["column_default"], + ["''", "'':::STRING"], + ) + + @engines_only("postgres") def test_alter_column_drop_default(self): """ Test setting a column default to None with MigrationManager. @@ -507,25 +869,87 @@ def test_alter_column_drop_default(self): ) # Run them all backwards - asyncio.run(manager_3.run_backwards()) + asyncio.run(manager_3.run(backwards=True)) self.assertEqual( self._get_column_default(), [{"column_default": None}], ) - asyncio.run(manager_2.run_backwards()) + asyncio.run(manager_2.run(backwards=True)) self.assertEqual( self._get_column_default(), [{"column_default": "'Mr Manager'::character varying"}], ) - asyncio.run(manager_1.run_backwards()) + asyncio.run(manager_1.run(backwards=True)) + self.assertEqual( + self._get_column_default(), + [{"column_default": None}], + ) + + @engines_only("cockroach") + def test_alter_column_drop_default_alt(self): + """ + Test setting a column default to None with MigrationManager. + """ + # Make sure it has a non-null default to start with. + manager_1 = MigrationManager() + manager_1.alter_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + params={"default": "Mr Manager"}, + old_params={"default": None}, + ) + asyncio.run(manager_1.run()) + self.assertIn( + self._get_column_default()[0]["column_default"], + ["'Mr Manager'", "'Mr Manager':::STRING"], + ) + + # Drop the default. 
+ manager_2 = MigrationManager() + manager_2.alter_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + params={"default": None}, + old_params={"default": "Mr Manager"}, + ) + asyncio.run(manager_2.run()) + self.assertEqual( + self._get_column_default(), + [{"column_default": None}], + ) + + # And add it back once more to be sure. + manager_3 = manager_1 + asyncio.run(manager_3.run()) + self.assertIn( + self._get_column_default()[0]["column_default"], + ["'Mr Manager'", "'Mr Manager':::STRING"], + ) + + # Run them all backwards + asyncio.run(manager_3.run(backwards=True)) self.assertEqual( self._get_column_default(), [{"column_default": None}], ) - @postgres_only + asyncio.run(manager_2.run(backwards=True)) + self.assertIn( + self._get_column_default()[0]["column_default"], + ["'Mr Manager'", "'Mr Manager':::STRING"], + ) + + asyncio.run(manager_1.run(backwards=True)) + self.assertEqual( + self._get_column_default(), + [{"column_default": None}], + ) + + @engines_only("postgres", "cockroach") def test_alter_column_add_index(self): """ Test altering a column to add an index with MigrationManager. @@ -545,13 +969,13 @@ def test_alter_column_add_index(self): Manager._get_index_name(["name"]) in Manager.indexes().run_sync() ) - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertTrue( Manager._get_index_name(["name"]) not in Manager.indexes().run_sync() ) - @postgres_only + @engines_only("postgres", "cockroach") def test_alter_column_set_type(self): """ Test altering a column to change it's type with MigrationManager. @@ -574,17 +998,18 @@ def test_alter_column_set_type(self): ) self.assertEqual(column_type_str, "TEXT") - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) column_type_str = self.get_postgres_column_type( tablename="manager", column_name="name" ) self.assertEqual(column_type_str, "CHARACTER VARYING") - @postgres_only + @engines_only("postgres") def test_alter_column_set_length(self): """ Test altering a Varchar column's length with MigrationManager. 
- """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" + """ # noqa: E501 manager = MigrationManager() manager.alter_column( @@ -605,7 +1030,7 @@ def test_alter_column_set_length(self): 500, ) - asyncio.run(manager.run_backwards()) + asyncio.run(manager.run(backwards=True)) self.assertEqual( self.get_postgres_varchar_length( tablename="manager", column_name="name" @@ -613,8 +1038,10 @@ def test_alter_column_set_length(self): 200, ) - @postgres_only - @patch.object(BaseMigrationManager, "get_migration_managers") + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) @patch.object(BaseMigrationManager, "get_app_config") def test_drop_table( self, get_app_config: MagicMock, get_migration_managers: MagicMock @@ -637,10 +1064,10 @@ def test_drop_table( self.assertTrue(not self.table_exists("musician")) # Reverse - set_mock_return_value(get_migration_managers, [manager_1]) + get_migration_managers.return_value = [manager_1] app_config = AppConfig(app_name="music", migrations_folder_path="") get_app_config.return_value = app_config - asyncio.run(manager_2.run_backwards()) + asyncio.run(manager_2.run(backwards=True)) get_migration_managers.assert_called_with( app_config=app_config, max_migration_id="2", offset=-1 @@ -648,3 +1075,60 @@ def test_drop_table( self.assertTrue(self.table_exists("musician")) self.run_sync("DROP TABLE IF EXISTS musician;") + + @engines_only("postgres", "cockroach") + def test_change_table_schema(self): + manager = MigrationManager(migration_id="1", app_name="music") + + manager.change_table_schema( + class_name="Manager", + tablename="manager", + new_schema="schema_1", + old_schema=None, + ) + + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + + output = fake_out.getvalue() + + self.assertEqual( + output, + ' - 1 [preview forwards]... CREATE SCHEMA IF NOT EXISTS "schema_1"\nALTER TABLE "manager" SET SCHEMA "schema_1"\n', # noqa: E501 + ) + + +class TestWrapInTransaction(IsolatedAsyncioTestCase): + + async def test_enabled(self): + """ + Make sure we can wrap the migration in a transaction if we want to. + """ + + async def run(): + db = engine_finder() + assert db + assert db.transaction_exists() is True + + manager = MigrationManager(wrap_in_transaction=True) + manager.add_raw(run) + + await manager.run() + + async def test_disabled(self): + """ + Make sure we can stop the migration being wrapped in a transaction if + we want to. 
+ """ + + async def run(): + db = engine_finder() + assert db + assert db.transaction_exists() is False + + manager = MigrationManager(wrap_in_transaction=False) + manager.add_raw(run) + + await manager.run() diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index 8a2f03de5..35dce6adc 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -1,25 +1,34 @@ from __future__ import annotations -import typing as t from unittest import TestCase - -from piccolo.apps.migrations.auto import DiffableTable, SchemaDiffer +from unittest.mock import MagicMock, call, patch + +from piccolo.apps.migrations.auto.schema_differ import ( + DiffableTable, + RenameColumn, + RenameColumnCollection, + RenameTable, + RenameTableCollection, + SchemaDiffer, +) from piccolo.columns.column_types import Numeric, Varchar class TestSchemaDiffer(TestCase): - def test_add_table(self): + maxDiff = None + + def test_add_table(self) -> None: """ Test adding a new table. """ name_column = Varchar() name_column._meta.name = "name" - schema: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", columns=[name_column] ) ] - schema_snapshot: t.List[DiffableTable] = [] + schema_snapshot: list[DiffableTable] = [] schema_differ = SchemaDiffer( schema=schema, schema_snapshot=schema_snapshot, auto_input="y" ) @@ -28,22 +37,22 @@ def test_add_table(self): self.assertTrue(len(create_tables.statements) == 1) self.assertEqual( create_tables.statements[0], - "manager.add_table('Band', tablename='band')", + "manager.add_table(class_name='Band', tablename='band', schema=None, columns=None)", # noqa: E501 ) new_table_columns = schema_differ.new_table_columns self.assertTrue(len(new_table_columns.statements) == 1) self.assertEqual( new_table_columns.statements[0], - "manager.add_column(table_class_name='Band', tablename='band', column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None})", # noqa + "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa ) - def test_drop_table(self): + def test_drop_table(self) -> None: """ Test dropping an existing table. """ - schema: t.List[DiffableTable] = [] - schema_snapshot: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [] + schema_snapshot: list[DiffableTable] = [ DiffableTable(class_name="Band", tablename="band", columns=[]) ] schema_differ = SchemaDiffer( @@ -53,22 +62,22 @@ def test_drop_table(self): self.assertTrue(len(schema_differ.drop_tables.statements) == 1) self.assertEqual( schema_differ.drop_tables.statements[0], - "manager.drop_table(class_name='Band', tablename='band')", + "manager.drop_table(class_name='Band', tablename='band', schema=None)", # noqa: E501 ) - def test_rename_table(self): + def test_rename_table(self) -> None: """ Test renaming a table. 
""" name_column = Varchar() name_column._meta.name = "name" - schema: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Act", tablename="act", columns=[name_column] ) ] - schema_snapshot: t.List[DiffableTable] = [ + schema_snapshot: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", columns=[name_column] ) @@ -81,13 +90,48 @@ def test_rename_table(self): self.assertTrue(len(schema_differ.rename_tables.statements) == 1) self.assertEqual( schema_differ.rename_tables.statements[0], - "manager.rename_table(old_class_name='Band', old_tablename='band', new_class_name='Act', new_tablename='act')", # noqa + "manager.rename_table(old_class_name='Band', old_tablename='band', new_class_name='Act', new_tablename='act', schema=None)", # noqa: E501 ) self.assertEqual(schema_differ.create_tables.statements, []) self.assertEqual(schema_differ.drop_tables.statements, []) - def test_add_column(self): + def test_change_schema(self) -> None: + """ + Testing changing the schema. + """ + schema: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[], + schema="schema_1", + ) + ] + schema_snapshot: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[], + schema=None, + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertEqual(len(schema_differ.change_table_schemas.statements), 1) + + self.assertEqual( + schema_differ.change_table_schemas.statements[0], + "manager.change_table_schema(class_name='Band', tablename='band', new_schema='schema_1', old_schema=None)", # noqa: E501 + ) + + self.assertListEqual(schema_differ.create_tables.statements, []) + self.assertListEqual(schema_differ.drop_tables.statements, []) + + def test_add_column(self) -> None: """ Test adding a column to an existing table. """ @@ -97,14 +141,14 @@ def test_add_column(self): genre_column = Varchar() genre_column._meta.name = "genre" - schema: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", columns=[name_column, genre_column], ) ] - schema_snapshot: t.List[DiffableTable] = [ + schema_snapshot: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", @@ -119,10 +163,10 @@ def test_add_column(self): self.assertTrue(len(schema_differ.add_columns.statements) == 1) self.assertEqual( schema_differ.add_columns.statements[0], - "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None})", # noqa + "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa: E501 ) - def test_drop_column(self): + def test_drop_column(self) -> None: """ Test dropping a column from an existing table. 
""" @@ -132,14 +176,14 @@ def test_drop_column(self): genre_column = Varchar() genre_column._meta.name = "genre" - schema: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", columns=[name_column], ) ] - schema_snapshot: t.List[DiffableTable] = [ + schema_snapshot: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", @@ -154,59 +198,240 @@ def test_drop_column(self): self.assertTrue(len(schema_differ.drop_columns.statements) == 1) self.assertEqual( schema_differ.drop_columns.statements[0], - "manager.drop_column(table_class_name='Band', tablename='band', column_name='genre')", # noqa + "manager.drop_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', schema=None)", # noqa: E501 ) - def test_rename_column(self): + def test_rename_column(self) -> None: """ Test renaming a column in an existing table. """ + # We're going to rename the 'name' column to 'title' name_column = Varchar() name_column._meta.name = "name" title_column = Varchar() title_column._meta.name = "title" - schema: t.List[DiffableTable] = [ + schema_snapshot: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", - columns=[name_column], + columns=[title_column], ) ] - schema_snapshot: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Band", tablename="band", - columns=[title_column], + columns=[name_column], ) ] + # Test 1 - Tell Piccolo the column was renamed schema_differ = SchemaDiffer( schema=schema, schema_snapshot=schema_snapshot, auto_input="y" ) + self.assertEqual(schema_differ.add_columns.statements, []) + self.assertEqual(schema_differ.drop_columns.statements, []) + self.assertEqual( + schema_differ.rename_columns.statements, + [ + "manager.rename_column(table_class_name='Band', tablename='band', old_column_name='title', new_column_name='name', old_db_column_name='title', new_db_column_name='name', schema=None)" # noqa: E501 + ], + ) + + # Test 2 - Tell Piccolo the column wasn't renamed + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="n" + ) + self.assertEqual( + schema_differ.add_columns.statements, + [ + "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)" # noqa: E501 + ], + ) + self.assertEqual( + schema_differ.drop_columns.statements, + [ + "manager.drop_column(table_class_name='Band', tablename='band', column_name='title', db_column_name='title', schema=None)" # noqa: E501 + ], + ) + self.assertTrue(schema_differ.rename_columns.statements == []) + + @patch("piccolo.apps.migrations.auto.schema_differ.input") + def test_rename_multiple_columns(self, input: MagicMock) -> None: + """ + Make sure renaming columns works when several columns have been + renamed. + """ + # We're going to rename a1 to a2, and b1 to b2. 
+ a1 = Varchar() + a1._meta.name = "a1" + + a2 = Varchar() + a2._meta.name = "a2" + + b1 = Varchar() + b1._meta.name = "b1" + + b2 = Varchar() + b2._meta.name = "b2" + + schema_snapshot: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a1, b1], + ) + ] + schema: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a2, b2], + ) + ] + + def mock_input(value: str): + """ + We need to dynamically set the return value based on what's passed + in. + """ + return ( + "y" + if value + in ( + "Did you rename the `a1` column to `a2` on the `Band` table? (y/N)", # noqa: E501 + "Did you rename the `b1` column to `b2` on the `Band` table? (y/N)", # noqa: E501 + ) + else "n" + ) + + input.side_effect = mock_input + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot + ) + self.assertEqual(schema_differ.add_columns.statements, []) + self.assertEqual(schema_differ.drop_columns.statements, []) + self.assertEqual( + schema_differ.rename_columns.statements, + [ + "manager.rename_column(table_class_name='Band', tablename='band', old_column_name='a1', new_column_name='a2', old_db_column_name='a1', new_db_column_name='a2', schema=None)", # noqa: E501 + "manager.rename_column(table_class_name='Band', tablename='band', old_column_name='b1', new_column_name='b2', old_db_column_name='b1', new_db_column_name='b2', schema=None)", # noqa: E501 + ], + ) + + self.assertEqual( + input.call_args_list, + [ + call( + "Did you rename the `a1` column to `a2` on the `Band` table? (y/N)" # noqa: E501 + ), + call( + "Did you rename the `b1` column to `b2` on the `Band` table? (y/N)" # noqa: E501 + ), + ], + ) + + @patch("piccolo.apps.migrations.auto.schema_differ.input") + def test_rename_some_columns(self, input: MagicMock): + """ + Make sure that some columns can be marked as renamed, and others are + dropped / created. + """ + # We're going to rename a1 to a2, but want b1 to be dropped, and b2 to + # be created. + a1 = Varchar() + a1._meta.name = "a1" + + a2 = Varchar() + a2._meta.name = "a2" + + b1 = Varchar() + b1._meta.name = "b1" + + b2 = Varchar() + b2._meta.name = "b2" + + schema_snapshot: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a1, b1], + ) + ] + schema: list[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a2, b2], + ) + ] + + def mock_input(value: str): + """ + We need to dynamically set the return value based on what's passed + in. + """ + return ( + "y" + if value + == "Did you rename the `a1` column to `a2` on the `Band` table? 
(y/N)" # noqa: E501 + else "n" + ) + + input.side_effect = mock_input + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot + ) + self.assertEqual( + schema_differ.add_columns.statements, + [ + "manager.add_column(table_class_name='Band', tablename='band', column_name='b2', db_column_name='b2', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)" # noqa: E501 + ], + ) + self.assertEqual( + schema_differ.drop_columns.statements, + [ + "manager.drop_column(table_class_name='Band', tablename='band', column_name='b1', db_column_name='b1', schema=None)" # noqa: E501 + ], + ) + self.assertEqual( + schema_differ.rename_columns.statements, + [ + "manager.rename_column(table_class_name='Band', tablename='band', old_column_name='a1', new_column_name='a2', old_db_column_name='a1', new_db_column_name='a2', schema=None)", # noqa: E501 + ], + ) - self.assertTrue(len(schema_differ.rename_columns.statements) == 1) self.assertEqual( - schema_differ.rename_columns.statements[0], - "manager.rename_column(table_class_name='Band', tablename='band', old_column_name='title', new_column_name='name')", # noqa + input.call_args_list, + [ + call( + "Did you rename the `a1` column to `a2` on the `Band` table? (y/N)" # noqa: E501 + ), + call( + "Did you rename the `b1` column to `b2` on the `Band` table? (y/N)" # noqa: E501 + ), + ], ) - def test_alter_column_precision(self): + def test_alter_column_precision(self) -> None: price_1 = Numeric(digits=(4, 2)) price_1._meta.name = "price" price_2 = Numeric(digits=(5, 2)) price_2._meta.name = "price" - schema: t.List[DiffableTable] = [ + schema: list[DiffableTable] = [ DiffableTable( class_name="Ticket", tablename="ticket", columns=[price_1], ) ] - schema_snapshot: t.List[DiffableTable] = [ + schema_snapshot: list[DiffableTable] = [ DiffableTable( class_name="Ticket", tablename="ticket", @@ -221,8 +446,97 @@ def test_alter_column_precision(self): self.assertTrue(len(schema_differ.alter_columns.statements) == 1) self.assertEqual( schema_differ.alter_columns.statements[0], - "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric)", # noqa + "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='price', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)", # noqa + ) + + def test_db_column_name(self) -> None: + """ + Make sure alter statements use the ``db_column_name`` if provided. 
+
+        https://github.com/piccolo-orm/piccolo/issues/513
+
+        """
+        price_1 = Numeric(digits=(4, 2), db_column_name="custom")
+        price_1._meta.name = "price"
+
+        price_2 = Numeric(digits=(5, 2), db_column_name="custom")
+        price_2._meta.name = "price"
+
+        schema: list[DiffableTable] = [
+            DiffableTable(
+                class_name="Ticket",
+                tablename="ticket",
+                columns=[price_1],
+            )
+        ]
+        schema_snapshot: list[DiffableTable] = [
+            DiffableTable(
+                class_name="Ticket",
+                tablename="ticket",
+                columns=[price_2],
+            )
+        ]
+
+        schema_differ = SchemaDiffer(
+            schema=schema, schema_snapshot=schema_snapshot, auto_input="y"
+        )
+
+        self.assertTrue(len(schema_differ.alter_columns.statements) == 1)
+        self.assertEqual(
+            schema_differ.alter_columns.statements[0],
+            "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='custom', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)",  # noqa
         )
 
     def test_alter_default(self):
         pass
+
+
+class TestRenameTableCollection(TestCase):
+    collection = RenameTableCollection(
+        rename_tables=[
+            RenameTable(
+                old_class_name="Manager",
+                old_tablename="manager",
+                new_class_name="Manager1",
+                new_tablename="manager_1",
+            )
+        ]
+    )
+
+    def test_was_renamed_from(self):
+        self.assertTrue(
+            self.collection.was_renamed_from(old_class_name="Manager")
+        )
+        self.assertFalse(
+            self.collection.was_renamed_from(old_class_name="Band")
+        )
+
+    def test_renamed_from(self):
+        self.assertEqual(
+            self.collection.renamed_from(new_class_name="Manager1"), "Manager"
+        )
+        self.assertIsNone(
+            self.collection.renamed_from(new_class_name="Band"),
+        )
+
+
+class TestRenameColumnCollection(TestCase):
+    def test_for_table_class_name(self):
+        rename_column = RenameColumn(
+            table_class_name="Manager",
+            tablename="manager",
+            old_column_name="name",
+            new_column_name="full_name",
+            old_db_column_name="name",
+            new_db_column_name="full_name",
+        )
+
+        collection = RenameColumnCollection(rename_columns=[rename_column])
+
+        self.assertListEqual(
+            collection.for_table_class_name(table_class_name="Manager"),
+            [rename_column],
+        )
+        self.assertListEqual(
+            collection.for_table_class_name(table_class_name="Band"), []
+        )
diff --git a/tests/apps/migrations/auto/test_serialisation.py b/tests/apps/migrations/auto/test_serialisation.py
index 8cc0639cf..7171af3ae 100644
--- a/tests/apps/migrations/auto/test_serialisation.py
+++ b/tests/apps/migrations/auto/test_serialisation.py
@@ -1,7 +1,19 @@
+import decimal
+import uuid
+import warnings
 from enum import Enum
 from unittest import TestCase
 
-from piccolo.apps.migrations.auto.serialisation import serialise_params
+import pytest
+
+from piccolo.apps.migrations.auto.serialisation import (
+    CanConflictWithGlobalNames,
+    Import,
+    UniqueGlobalNameConflictWarning,
+    UniqueGlobalNames,
+    UniqueGlobalNamesMeta,
+    serialise_params,
+)
 from piccolo.columns.base import OnDelete
 from piccolo.columns.choices import Choice
 from piccolo.columns.column_types import Varchar
@@ -9,6 +21,151 @@
 from piccolo.columns.reference import LazyTableReference
 
 
+class TestUniqueGlobalNamesMeta:
+    def test_duplicate_class_attribute_values_raises_error(self):
+        with pytest.raises(ValueError):
+
+            class IncorrectUniqueGlobalNames(metaclass=UniqueGlobalNamesMeta):
+                A = "duplicate"
+                B = "duplicate"
+
+
+class TestUniqueGlobals:
+    def test_contains_column_types(self):
+        # Compare the constants directly, so a missing attribute or an
+        # unexpected value makes the test fail.
+        assert UniqueGlobalNames.COLUMN_VARCHAR == "Varchar"
+        assert UniqueGlobalNames.COLUMN_SECRET == "Secret"
+        assert UniqueGlobalNames.COLUMN_TEXT == "Text"
+        assert UniqueGlobalNames.COLUMN_UUID == "UUID"
+        assert UniqueGlobalNames.COLUMN_INTEGER == "Integer"
+        assert UniqueGlobalNames.COLUMN_BIGINT == "BigInt"
+        assert UniqueGlobalNames.COLUMN_SMALLINT == "SmallInt"
+        assert UniqueGlobalNames.COLUMN_SERIAL == "Serial"
+        assert UniqueGlobalNames.COLUMN_BIGSERIAL == "BigSerial"
+        assert UniqueGlobalNames.COLUMN_PRIMARYKEY == "PrimaryKey"
+        assert UniqueGlobalNames.COLUMN_TIMESTAMP == "Timestamp"
+        assert UniqueGlobalNames.COLUMN_TIMESTAMPZ == "Timestampz"
+        assert UniqueGlobalNames.COLUMN_DATE == "Date"
+        assert UniqueGlobalNames.COLUMN_TIME == "Time"
+        assert UniqueGlobalNames.COLUMN_INTERVAL == "Interval"
+        assert UniqueGlobalNames.COLUMN_BOOLEAN == "Boolean"
+        assert UniqueGlobalNames.COLUMN_NUMERIC == "Numeric"
+        assert UniqueGlobalNames.COLUMN_DECIMAL == "Decimal"
+        assert UniqueGlobalNames.COLUMN_FLOAT == "Float"
+        assert UniqueGlobalNames.COLUMN_DOUBLEPERCISION == "DoublePrecision"
+        assert UniqueGlobalNames.COLUMN_FOREIGNKEY == "ForeignKey"
+        assert UniqueGlobalNames.COLUMN_JSON == "JSON"
+        assert UniqueGlobalNames.COLUMN_BYTEA == "Bytea"
+        assert UniqueGlobalNames.COLUMN_BLOB == "Blob"
+        assert UniqueGlobalNames.COLUMN_ARRAY == "Array"
+
+    def test_warn_if_is_conflicting_name(self):
+        with warnings.catch_warnings() as recorded_warnings:
+            warnings.simplefilter("error")
+            UniqueGlobalNames.warn_if_is_conflicting_name(
+                "SuperMassiveBlackHole"
+            )
+
+        with pytest.warns(
+            UniqueGlobalNameConflictWarning
+        ) as recorded_warnings:
+            UniqueGlobalNames.warn_if_is_conflicting_name("Varchar")
+
+        if len(recorded_warnings) != 1:
+            pytest.fail("Expected 1 warning!")
+
+    def test_is_conflicting_name(self):
+        assert (
+            UniqueGlobalNames.is_conflicting_name("SuperMassiveBlackHole")
+            is False
+        )
+        assert UniqueGlobalNames.is_conflicting_name("Varchar") is True
+
+    def test_warn_if_are_conflicting_objects(self):
+        class ConflictingCls1(CanConflictWithGlobalNames):
+            def warn_if_is_conflicting_with_global_name(self):
+                pass
+
+        class ConflictingCls2(CanConflictWithGlobalNames):
+            def warn_if_is_conflicting_with_global_name(self):
+                pass
+
+        class ConflictingCls3(CanConflictWithGlobalNames):
+            def warn_if_is_conflicting_with_global_name(self):
+                warnings.warn("test", UniqueGlobalNameConflictWarning)
+
+        with warnings.catch_warnings() as recorded_warnings:
+            warnings.simplefilter("error")
+            UniqueGlobalNames.warn_if_are_conflicting_objects(
+                [ConflictingCls1(), ConflictingCls2()]
+            )
+
+        with pytest.warns(
+            UniqueGlobalNameConflictWarning
+        ) as recorded_warnings:
+            UniqueGlobalNames.warn_if_are_conflicting_objects(
+                [ConflictingCls2(), ConflictingCls3()]
+            )
+
+        if len(recorded_warnings) != 1:
+            pytest.fail("Expected 1 warning!")
+
+
+class TestImport:
+    def test_with_module_only(self):
+        assert repr(Import(module="a.b.c")) == "import a.b.c"
+
+    def test_with_module_and_target(self):
+        assert repr(Import(module="a.b", target="c")) == "from a.b import c"
+
+    def test_warn_if_is_conflicting_with_global_name_with_module_only(self):
+        with warnings.catch_warnings() as recorded_warnings:
+            warnings.simplefilter("error")
+            Import(module="a.b.c").warn_if_is_conflicting_with_global_name()
+
+        with pytest.warns(
+            UniqueGlobalNameConflictWarning
+        ) as 
recorded_warnings: + Import(module="Varchar").warn_if_is_conflicting_with_global_name() + + if len(recorded_warnings) != 1: + pytest.fail("Expected 1 warning!") + + with warnings.catch_warnings() as recorded_warnings: + warnings.simplefilter("error") + Import( + module="Varchar", expect_conflict_with_global_name="Varchar" + ).warn_if_is_conflicting_with_global_name() + + def test_warn_if_is_conflicting_with_global_name_with_module_and_target( + self, + ): + with warnings.catch_warnings() as recorded_warnings: + warnings.simplefilter("error") + Import( + module="a.b", target="c" + ).warn_if_is_conflicting_with_global_name() + + with pytest.warns( + UniqueGlobalNameConflictWarning + ) as recorded_warnings: + Import( + module="a.b", target="Varchar" + ).warn_if_is_conflicting_with_global_name() + + if len(recorded_warnings) != 1: + pytest.fail("Expected 1 warning!") + + with warnings.catch_warnings() as recorded_warnings: + warnings.simplefilter("error") + Import( + module="a.b", + target="Varchar", + expect_conflict_with_global_name="Varchar", + ).warn_if_is_conflicting_with_global_name() + + def example_function(): pass @@ -34,18 +191,28 @@ def test_timestamp(self): ) def test_uuid(self): + serialised = serialise_params(params={"default": uuid.UUID(int=4)}) + assert ( + repr(serialised.params["default"]) + == 'uuid.UUID("00000000-0000-0000-0000-000000000004")' + ) + serialised = serialise_params(params={"default": UUID4()}) self.assertTrue(serialised.params["default"].__repr__() == "UUID4()") + def test_decimal(self): + serialised = serialise_params( + params={"default": decimal.Decimal("1.2")} + ) + assert repr(serialised.params["default"]) == 'decimal.Decimal("1.2")' + def test_lazy_table_reference(self): # These are equivalent: references_list = [ - LazyTableReference( - table_class_name="Manager", app_name="example_app" - ), + LazyTableReference(table_class_name="Manager", app_name="music"), LazyTableReference( table_class_name="Manager", - module_path="tests.example_app.tables", + module_path="tests.example_apps.music.tables", ), ] @@ -54,21 +221,28 @@ def test_lazy_table_reference(self): self.assertTrue( serialised.params["references"].__repr__() == "Manager" ) + # sorted extra_imports for consistency between tests + sorted_extra_imports = sorted(serialised.extra_imports) - self.assertTrue(len(serialised.extra_imports) == 1) + self.assertTrue(len(serialised.extra_imports) == 2) + self.assertEqual( + sorted_extra_imports[0].__str__(), + "from piccolo.columns.column_types import Serial", + ) self.assertEqual( - serialised.extra_imports[0].__str__(), + sorted_extra_imports[1].__str__(), "from piccolo.table import Table", ) self.assertTrue(len(serialised.extra_definitions) == 1) + self.assertEqual( serialised.extra_definitions[0].__str__(), ( - 'class Manager(Table, tablename="manager"): ' + 'class Manager(Table, tablename="manager", schema=None): ' "id = Serial(null=False, primary_key=True, unique=False, " "index=False, index_method=IndexMethod.btree, " - "choices=None)" + "choices=None, db_column_name='id', secret=False)" ), ) @@ -119,7 +293,7 @@ def test_column_instance(self): self.assertEqual( serialised.params["base_column"].__repr__(), - "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None)", # noqa: E501 + "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)", # noqa: E501 ) 
self.assertEqual( diff --git a/tests/apps/migrations/commands/test_base.py b/tests/apps/migrations/commands/test_base.py index fb80f2c28..25d1b9494 100644 --- a/tests/apps/migrations/commands/test_base.py +++ b/tests/apps/migrations/commands/test_base.py @@ -43,9 +43,9 @@ def test_get_migration_modules(self): class TestGetTableFromSnapshot(TestCase): @patch.object(BaseMigrationManager, "get_app_config") - def test_get_table_from_snaphot(self, get_app_config: MagicMock): + def test_get_table_from_snapshot(self, get_app_config: MagicMock): """ - Test the get_table_from_snaphot method. + Test the get_table_from_snapshot method. """ get_app_config.return_value = AppConfig( app_name="music", @@ -55,7 +55,7 @@ def test_get_table_from_snaphot(self, get_app_config: MagicMock): ) table = run_sync( - BaseMigrationManager().get_table_from_snaphot( + BaseMigrationManager().get_table_from_snapshot( app_name="music", table_class_name="Band" ) ) diff --git a/tests/apps/migrations/commands/test_check.py b/tests/apps/migrations/commands/test_check.py index 397c28336..b88cbc378 100644 --- a/tests/apps/migrations/commands/test_check.py +++ b/tests/apps/migrations/commands/test_check.py @@ -1,12 +1,17 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch from piccolo.apps.migrations.commands.check import CheckMigrationManager, check +from piccolo.apps.migrations.tables import Migration from piccolo.conf.apps import AppRegistry from piccolo.utils.sync import run_sync -class TestCheckMigrationCommand(TestCase): +class TestCheckMigrationCommand(IsolatedAsyncioTestCase): + + async def asyncTearDown(self): + await Migration.alter().drop_table(if_exists=True) + @patch.object( CheckMigrationManager, "get_app_registry", diff --git a/tests/apps/migrations/commands/test_clean.py b/tests/apps/migrations/commands/test_clean.py index 37b3ae5a2..a6ff8c2b4 100644 --- a/tests/apps/migrations/commands/test_clean.py +++ b/tests/apps/migrations/commands/test_clean.py @@ -20,14 +20,14 @@ def test_clean(self): migration_ids = real_migration_ids + [orphaned_migration_id] Migration.insert( - *[Migration(name=i, app_name="example_app") for i in migration_ids] + *[Migration(name=i, app_name="music") for i in migration_ids] ).run_sync() - run_sync(clean(app_name="example_app", auto_agree=True)) + run_sync(clean(app_name="music", auto_agree=True)) remaining_rows = ( Migration.select(Migration.name) - .where(Migration.app_name == "example_app") + .where(Migration.app_name == "music") .output(as_list=True) .order_by(Migration.name) .run_sync() diff --git a/tests/apps/migrations/commands/test_forwards_backwards.py b/tests/apps/migrations/commands/test_forwards_backwards.py index dd5c8e5b3..1ccc5bce7 100644 --- a/tests/apps/migrations/commands/test_forwards_backwards.py +++ b/tests/apps/migrations/commands/test_forwards_backwards.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -import typing as t +from typing import TYPE_CHECKING from unittest import TestCase from unittest.mock import MagicMock, call, patch @@ -9,10 +9,11 @@ from piccolo.apps.migrations.commands.forwards import forwards from piccolo.apps.migrations.tables import Migration from piccolo.utils.sync import run_sync -from tests.base import postgres_only -from tests.example_app.tables import ( +from tests.base import engines_only +from tests.example_apps.music.tables import ( Band, Concert, + Instrument, Manager, Poster, RecordingStudio, @@ -21,11 +22,10 @@ Venue, ) -if t.TYPE_CHECKING: # 
pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.table import Table - -TABLE_CLASSES: t.List[t.Type[Table]] = [ +TABLE_CLASSES: list[type[Table]] = [ Manager, Band, Venue, @@ -34,10 +34,11 @@ Poster, Shirt, RecordingStudio, + Instrument, ] -@postgres_only +@engines_only("postgres", "cockroach") class TestForwardsBackwards(TestCase): """ Test the forwards and backwards migration commands. @@ -47,12 +48,13 @@ def test_forwards_backwards_all_migrations(self): """ Test running all of the migrations forwards, then backwards. """ - for app_name in ("example_app", "all"): + for app_name in ("music", "all"): run_sync(forwards(app_name=app_name, migration_id="all")) # Check the tables exist for table_class in TABLE_CLASSES: self.assertTrue(table_class.table_exists().run_sync()) + self.assertNotEqual(Migration.count().run_sync(), 0) run_sync( backwards( @@ -63,25 +65,37 @@ def test_forwards_backwards_all_migrations(self): # Check the tables don't exist for table_class in TABLE_CLASSES: self.assertTrue(not table_class.table_exists().run_sync()) + self.assertEqual(Migration.count().run_sync(), 0) + # Preview + run_sync( + forwards(app_name=app_name, migration_id="all", preview=True) + ) + for table_class in TABLE_CLASSES: + self.assertTrue(not table_class.table_exists().run_sync()) + self.assertEqual(Migration.count().run_sync(), 0) def test_forwards_backwards_single_migration(self): """ Test running a single migrations forwards, then backwards. """ - for migration_id in ["1", "2020-12-17T18:44:30"]: - run_sync( - forwards(app_name="example_app", migration_id=migration_id) - ) + table_classes = [Band, Manager] - table_classes = [Band, Manager] + for migration_id in ["1", "2020-12-17T18:44:30"]: + run_sync(forwards(app_name="music", migration_id=migration_id)) # Check the tables exist for table_class in table_classes: self.assertTrue(table_class.table_exists().run_sync()) + self.assertTrue( + Migration.exists() + .where(Migration.name == "2020-12-17T18:44:30") + .run_sync() + ) + run_sync( backwards( - app_name="example_app", + app_name="music", migration_id=migration_id, auto_agree=True, ) @@ -90,6 +104,25 @@ def test_forwards_backwards_single_migration(self): # Check the tables don't exist for table_class in table_classes: self.assertTrue(not table_class.table_exists().run_sync()) + self.assertFalse( + Migration.exists() + .where(Migration.name == "2020-12-17T18:44:30") + .run_sync() + ) + + # Preview + run_sync( + forwards( + app_name="music", migration_id=migration_id, preview=True + ) + ) + for table_class in table_classes: + self.assertTrue(not table_class.table_exists().run_sync()) + self.assertFalse( + Migration.exists() + .where(Migration.name == "2020-12-17T18:44:30") + .run_sync() + ) @patch("piccolo.apps.migrations.commands.forwards.print") def test_forwards_unknown_migration(self, print_: MagicMock): @@ -98,9 +131,7 @@ def test_forwards_unknown_migration(self, print_: MagicMock): """ with self.assertRaises(SystemExit): run_sync( - forwards( - app_name="example_app", migration_id="migration-12345" - ) + forwards(app_name="music", migration_id="migration-12345") ) self.assertTrue( @@ -113,12 +144,12 @@ def test_backwards_unknown_migration(self, print_: MagicMock): """ Test running an unknown migrations backwards. 
""" - run_sync(forwards(app_name="example_app", migration_id="all")) + run_sync(forwards(app_name="music", migration_id="all")) with self.assertRaises(SystemExit): run_sync( backwards( - app_name="example_app", + app_name="music", migration_id="migration-12345", auto_agree=True, ) @@ -137,32 +168,33 @@ def test_backwards_no_migrations(self, print_: MagicMock): """ run_sync( backwards( - app_name="example_app", + app_name="music", migration_id="2020-12-17T18:44:30", auto_agree=True, ) ) - self.assertTrue(call("No migrations to reverse!") in print_.mock_calls) + self.assertTrue( + call("🏁 No migrations to reverse!") in print_.mock_calls + ) @patch("piccolo.apps.migrations.commands.forwards.print") def test_forwards_no_migrations(self, print_: MagicMock): """ Test running the migrations if they've already run. """ - run_sync(forwards(app_name="example_app", migration_id="all")) - run_sync(forwards(app_name="example_app", migration_id="all")) + run_sync(forwards(app_name="music", migration_id="all")) + run_sync(forwards(app_name="music", migration_id="all")) self.assertTrue( - print_.mock_calls[-1] == call("No migrations left to run!") + print_.mock_calls[-1] == call("🏁 No migrations need to be run") ) + @engines_only("postgres") def test_forwards_fake(self): """ - Test running the migrations if they've already run. + Make sure migrations can be faked on the command line. """ - run_sync( - forwards(app_name="example_app", migration_id="all", fake=True) - ) + run_sync(forwards(app_name="music", migration_id="all", fake=True)) for table_class in TABLE_CLASSES: self.assertTrue(not table_class.table_exists().run_sync()) @@ -179,9 +211,43 @@ def test_forwards_fake(self): "2020-12-17T18:44:39", "2020-12-17T18:44:44", "2021-07-25T22:38:48:009306", + "2021-09-06T13:58:23:024723", + "2021-11-13T14:01:46:114725", + "2024-05-28T23:15:41:018844", + "2024-06-19T18:11:05:793132", ], ) + @engines_only("postgres") + @patch("piccolo.apps.migrations.commands.forwards.print") + def test_hardcoded_fake_migrations(self, print_: MagicMock): + """ + Make sure that migrations that have been hardcoded as fake aren't + executed, even without the ``--fake`` command line flag. + + See tests/example_apps/music/piccolo_migrations/music_2024_06_19t18_11_05_793132.py + + """ # noqa: E501 + run_sync(forwards(app_name="music", migration_id="all")) + + # The migration which is hardcoded as fake: + migration_name = "2024-06-19T18:11:05:793132" + + self.assertTrue( + Migration.exists() + .where(Migration.name == migration_name) + .run_sync() + ) + + self.assertNotIn( + call("Running fake migration"), + print_.mock_calls, + ) + self.assertIn( + call(f"- {migration_name}: faked! 
⏭️"), + print_.mock_calls, + ) + def tearDown(self): for table_class in TABLE_CLASSES + [Migration]: table_class.alter().drop_table( diff --git a/tests/apps/migrations/commands/test_migrations/2020-03-31T20-38-22.py b/tests/apps/migrations/commands/test_migrations/2020-03-31T20-38-22.py index 347020d44..bcdf8e8e2 100644 --- a/tests/apps/migrations/commands/test_migrations/2020-03-31T20-38-22.py +++ b/tests/apps/migrations/commands/test_migrations/2020-03-31T20-38-22.py @@ -15,8 +15,7 @@ async def forwards(): "length": 150, "default": "", "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, diff --git a/tests/apps/migrations/commands/test_new.py b/tests/apps/migrations/commands/test_new.py index 5d7fa266f..da47877c1 100644 --- a/tests/apps/migrations/commands/test_new.py +++ b/tests/apps/migrations/commands/test_new.py @@ -1,3 +1,4 @@ +import datetime import os import shutil import tempfile @@ -7,16 +8,17 @@ from piccolo.apps.migrations.commands.new import ( BaseMigrationManager, _create_new_migration, + _generate_migration_meta, new, ) from piccolo.conf.apps import AppConfig from piccolo.utils.sync import run_sync -from tests.base import postgres_only -from tests.example_app.tables import Manager +from tests.base import engines_only +from tests.example_apps.music.tables import Manager class TestNewMigrationCommand(TestCase): - def test_create_new_migration(self): + def test_manual(self): """ Create a manual migration (i.e. non-auto). """ @@ -42,17 +44,97 @@ def test_create_new_migration(self): self.assertTrue(len(migration_modules.keys()) == 1) - @postgres_only + @engines_only("postgres") @patch("piccolo.apps.migrations.commands.new.print") - def test_new_command(self, print_: MagicMock): + def test_auto(self, print_: MagicMock): """ Call the command, when no migration changes are needed. """ - with self.assertRaises(SystemExit) as manager: - run_sync(new(app_name="example_app", auto=True)) + run_sync(new(app_name="music", auto=True)) - self.assertEqual(manager.exception.code, 0) + self.assertListEqual( + print_.call_args_list, + [ + call("🚀 Creating new migration ..."), + call("🏁 No changes detected."), + call("\n✅ Finished\n"), + ], + ) + + @engines_only("postgres") + @patch("piccolo.apps.migrations.commands.new.print") + def test_auto_all(self, print_: MagicMock): + """ + Try auto migrating all apps. + """ + run_sync(new(app_name="all", auto=True)) + self.assertListEqual( + print_.call_args_list, + [ + call("🚀 Creating new migration ..."), + call("🏁 No changes detected."), + call("🚀 Creating new migration ..."), + call("🏁 No changes detected."), + call("\n✅ Finished\n"), + ], + ) - self.assertTrue( - print_.mock_calls[-1] == call("No changes detected - exiting.") + @engines_only("postgres") + def test_auto_all_error(self): + """ + Call the command, when no migration changes are needed. + """ + with self.assertRaises(ValueError) as manager: + run_sync(new(app_name="all", auto=False)) + + self.assertEqual( + manager.exception.__str__(), + "Only use `--app_name=all` in conjunction with `--auto`.", + ) + + +class TestGenerateMigrationMeta(TestCase): + @patch("piccolo.apps.migrations.commands.new.now") + def test_filename(self, now: MagicMock): + now.return_value = datetime.datetime( + year=2022, + month=1, + day=10, + hour=7, + minute=15, + second=20, + microsecond=3000, + ) + + # Try with an app name which already contains valid characters for a + # Python module. 
+        migration_meta = _generate_migration_meta(
+            app_config=AppConfig(
+                app_name="app_name",
+                migrations_folder_path="/tmp/",
+            )
+        )
+        self.assertEqual(
+            migration_meta.migration_filename,
+            "app_name_2022_01_10t07_15_20_003000",
+        )
+        self.assertEqual(
+            migration_meta.migration_path,
+            "/tmp/app_name_2022_01_10t07_15_20_003000.py",
+        )
+
+        # Try with an app name with invalid characters for a Python module.
+        migration_meta = _generate_migration_meta(
+            app_config=AppConfig(
+                app_name="App-Name!",
+                migrations_folder_path="/tmp/",
+            )
+        )
+        self.assertEqual(
+            migration_meta.migration_filename,
+            "app_name_2022_01_10t07_15_20_003000",
+        )
+        self.assertEqual(
+            migration_meta.migration_path,
+            "/tmp/app_name_2022_01_10t07_15_20_003000.py",
         )
diff --git a/tests/apps/schema/__init__.py b/tests/apps/schema/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/apps/schema/commands/test_generate.py b/tests/apps/schema/commands/test_generate.py
new file mode 100644
index 000000000..5168ab060
--- /dev/null
+++ b/tests/apps/schema/commands/test_generate.py
@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+import ast
+import asyncio
+from typing import cast
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+from piccolo.apps.schema.commands.exceptions import GenerateError
+from piccolo.apps.schema.commands.generate import (
+    OutputSchema,
+    generate,
+    get_output_schema,
+)
+from piccolo.columns.base import Column
+from piccolo.columns.column_types import (
+    ForeignKey,
+    Integer,
+    Timestamp,
+    Varchar,
+)
+from piccolo.columns.indexes import IndexMethod
+from piccolo.schema import SchemaManager
+from piccolo.table import Table, create_db_tables_sync
+from piccolo.utils.sync import run_sync
+from tests.base import AsyncMock, engines_only, engines_skip
+from tests.example_apps.mega.tables import MegaTable, SmallTable
+
+
+@engines_only("postgres", "cockroach")
+class TestGenerate(TestCase):
+    def setUp(self):
+        for table_class in (SmallTable, MegaTable):
+            table_class.create_table().run_sync()
+
+    def tearDown(self):
+        for table_class in (MegaTable, SmallTable):
+            table_class.alter().drop_table().run_sync()
+
+    def _compare_table_columns(
+        self, table_1: type[Table], table_2: type[Table]
+    ):
+        """
+        Make sure that for each column in table_1, there is a corresponding
+        column in table_2 of the same type.
+        """
+        column_names = [
+            column._meta.name for column in table_1._meta.non_default_columns
+        ]
+        for column_name in column_names:
+            col_1 = table_1._meta.get_column_by_name(column_name)
+            col_2 = table_2._meta.get_column_by_name(column_name)
+
+            # Make sure they're the same type
+            self.assertEqual(type(col_1), type(col_2))
+
+            # Make sure they're both nullable or not
+            self.assertEqual(col_1._meta.null, col_2._meta.null)
+
+            # Make sure the max length is the same
+            if isinstance(col_1, Varchar) and isinstance(col_2, Varchar):
+                self.assertEqual(col_1.length, col_2.length)
+
+            # Make sure the unique constraint is the same
+            self.assertEqual(col_1._meta.unique, col_2._meta.unique)
+
+    def test_get_output_schema(self) -> None:
+        """
+        Make sure that a Piccolo schema can be generated from the database.
+ """ + output_schema: OutputSchema = run_sync(get_output_schema()) + + self.assertTrue(len(output_schema.warnings) == 0) + self.assertTrue(len(output_schema.tables) == 2) + self.assertTrue(len(output_schema.imports) > 0) + + MegaTable_ = output_schema.get_table_with_name("MegaTable") + assert MegaTable_ is not None + self._compare_table_columns(MegaTable, MegaTable_) + + SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None + self._compare_table_columns(SmallTable, SmallTable_) + + @patch("piccolo.apps.schema.commands.generate.print") + def test_generate_command(self, print_: MagicMock): + """ + Test the main generate command runs without errors. + """ + run_sync(generate()) + file_contents = print_.call_args[0][0] + + # Make sure the output is valid Python code (will raise a SyntaxError + # exception otherwise). + ast.parse(file_contents) + + # Cockroach throws FeatureNotSupportedError, which does not pass this test. + @engines_skip("cockroach") + def test_unknown_column_type(self) -> None: + """ + Make sure unknown column types are handled gracefully. + """ + + class Box(Column): + """ + A column type which isn't supported by Piccolo officially yet. + """ + + pass + + MegaTable.alter().add_column("my_column", Box()).run_sync() + + output_schema: OutputSchema = run_sync(get_output_schema()) + + # Make sure there's a warning. + self.assertEqual( + output_schema.warnings, ["mega_table.my_column ['box']"] + ) + + # Make sure the column type of the generated table is just ``Column``. + for table in output_schema.tables: + if table.__name__ == "MegaTable": + self.assertEqual( + output_schema.tables[1] + ._meta.get_column_by_name("my_column") + .__class__.__name__, + "Column", + ) + + def test_generate_required_tables(self) -> None: + """ + Make sure only tables passed to `tablenames` are created. + """ + output_schema: OutputSchema = run_sync( + get_output_schema(include=[SmallTable._meta.tablename]) + ) + self.assertEqual(len(output_schema.tables), 1) + SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None + self._compare_table_columns(SmallTable, SmallTable_) + + def test_exclude_table(self) -> None: + """ + Make sure exclude works. + """ + output_schema: OutputSchema = run_sync( + get_output_schema(exclude=[MegaTable._meta.tablename]) + ) + self.assertEqual(len(output_schema.tables), 1) + SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None + self._compare_table_columns(SmallTable, SmallTable_) + + @engines_skip("cockroach") + def test_self_referencing_fk(self) -> None: + """ + Make sure self-referencing foreign keys are handled correctly. + """ + + MegaTable.alter().add_column( + "self_referencing_fk", ForeignKey("self") + ).run_sync() + + output_schema: OutputSchema = run_sync(get_output_schema()) + + # Make sure the 'references' value of the generated column is "self". 
+ for table in output_schema.tables: + if table.__name__ == "MegaTable": + column = cast( + ForeignKey, + output_schema.tables[1]._meta.get_column_by_name( + "self_referencing_fk" + ), + ) + + self.assertEqual( + column._foreign_key_meta.resolved_references._meta.tablename, # noqa: E501 + MegaTable._meta.tablename, + ) + self.assertEqual(column._meta.params["references"], "self") + + +############################################################################### + + +class Concert(Table): + name = Varchar(index=True, index_method=IndexMethod.hash) + time = Timestamp( + index=True + ) # Testing a column with the same name as a Postgres data type. + capacity = Integer(index=False) + + +@engines_only("postgres") +class TestGenerateWithIndexes(TestCase): + def setUp(self): + Concert.create_table().run_sync() + + def tearDown(self): + Concert.alter().drop_table(if_exists=True).run_sync() + + def test_index(self) -> None: + """ + Make sure that a table with an index is reflected correctly. + """ + output_schema: OutputSchema = run_sync(get_output_schema()) + Concert_ = output_schema.tables[0] + + name_column = Concert_._meta.get_column_by_name("name") + self.assertTrue(name_column._meta.index) + self.assertEqual(name_column._meta.index_method, IndexMethod.hash) + + time_column = Concert_._meta.get_column_by_name("time") + self.assertTrue(time_column._meta.index) + self.assertEqual(time_column._meta.index_method, IndexMethod.btree) + + capacity_column = Concert_._meta.get_column_by_name("capacity") + self.assertEqual(capacity_column._meta.index, False) + self.assertEqual(capacity_column._meta.index_method, IndexMethod.btree) + + +############################################################################### + + +class Publication(Table, tablename="publication", schema="schema_2"): + name = Varchar(length=50) + + +class Writer(Table, tablename="writer", schema="schema_1"): + name = Varchar(length=50) + publication = ForeignKey(Publication, null=True) + + +class Book(Table): + name = Varchar(length=50) + writer = ForeignKey(Writer, null=True) + popularity = Integer(default=0) + + +@engines_only("postgres") +class TestGenerateWithSchema(TestCase): + tables = [Publication, Writer, Book] + + schema_manager = SchemaManager() + + def setUp(self) -> None: + for schema_name in ("schema_1", "schema_2"): + self.schema_manager.create_schema( + schema_name=schema_name, if_not_exists=True + ).run_sync() + + create_db_tables_sync(*self.tables) + + def tearDown(self) -> None: + Book.alter().drop_table().run_sync() + + for schema_name in ("schema_1", "schema_2"): + self.schema_manager.drop_schema( + schema_name=schema_name, if_exists=True, cascade=True + ).run_sync() + + def test_reference_to_another_schema(self) -> None: + output_schema: OutputSchema = run_sync(get_output_schema()) + self.assertEqual(len(output_schema.tables), 3) + publication = output_schema.tables[0] + writer = output_schema.tables[1] + book = output_schema.tables[2] + # Make sure referenced tables have been created + self.assertEqual( + Publication._meta.tablename, publication._meta.tablename + ) + self.assertEqual(Writer._meta.tablename, writer._meta.tablename) + + # Make sure foreign key values are correct. 
+        self.assertEqual(
+            writer._meta.get_column_by_name(
+                "publication"
+            )._foreign_key_meta.resolved_references._meta.tablename,
+            publication._meta.tablename,
+        )
+        self.assertEqual(
+            book._meta.get_column_by_name(
+                "writer"
+            )._foreign_key_meta.resolved_references._meta.tablename,
+            writer._meta.tablename,
+        )
+
+
+@engines_only("postgres", "cockroach")
+class TestGenerateWithException(TestCase):
+    def setUp(self):
+        for table_class in (SmallTable, MegaTable):
+            table_class.create_table().run_sync()
+
+    def tearDown(self):
+        for table_class in (MegaTable, SmallTable):
+            table_class.alter().drop_table(if_exists=True).run_sync()
+
+    @patch(
+        "piccolo.apps.schema.commands.generate.create_table_class_from_db",
+        new_callable=AsyncMock,
+    )
+    def test_exception(self, create_table_class_from_db_mock: AsyncMock):
+        """
+        Make sure that a GenerateError exception is raised with all the
+        exceptions gathered.
+        """
+        create_table_class_from_db_mock.side_effect = [
+            ValueError("Test"),
+            TypeError("Test"),
+        ]
+
+        # Make sure the exception is raised.
+        with self.assertRaises(GenerateError) as e:
+            asyncio.run(get_output_schema())
+
+        # Make sure the exception contains the correct number of errors.
+        self.assertEqual(len(e.exception.args[0]), 2)
+        # assert that the two exceptions are ValueError and TypeError
+        exception_types = [type(e) for e in e.exception.args[0]]
+        self.assertIn(ValueError, exception_types)
+        self.assertIn(TypeError, exception_types)
+
+        # Make sure the exception contains the correct error messages.
+        exception_messages = [str(e) for e in e.exception.args[0]]
+        self.assertIn(
+            "Exception occurred while generating `small_table` table: Test",
+            exception_messages,
+        )
+        self.assertIn(
+            "Exception occurred while generating `mega_table` table: Test",
+            exception_messages,
+        )
diff --git a/tests/apps/schema/commands/test_graph.py b/tests/apps/schema/commands/test_graph.py
new file mode 100644
index 000000000..16e541a43
--- /dev/null
+++ b/tests/apps/schema/commands/test_graph.py
@@ -0,0 +1,47 @@
+import os
+import tempfile
+import uuid
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+from piccolo.apps.schema.commands.graph import graph
+
+
+class TestGraph(TestCase):
+    def _verify_contents(self, file_contents: str):
+        """
+        Make sure the contents of the file are correct.
+        """
+        # Make sure no extra content was output at the start.
+        self.assertTrue(file_contents.startswith("digraph model_graph"))
+
+        # Make sure the tables are present
+        self.assertTrue("TABLE_Band [label" in file_contents)
+        self.assertTrue("TABLE_Manager [label" in file_contents)
+
+        # Make sure a relation is present
+        self.assertTrue("TABLE_Concert -> TABLE_Band" in file_contents)
+
+    @patch("piccolo.apps.schema.commands.graph.print")
+    def test_graph(self, print_: MagicMock):
+        """
+        Make sure the file contents can be printed to stdout.
+        """
+        graph()
+        file_contents = print_.call_args[0][0]
+        self._verify_contents(file_contents)
+
+    def test_graph_to_file(self):
+        """
+        Make sure the file contents can be written to disk.
+ """ + directory = tempfile.gettempdir() + path = os.path.join(directory, f"{uuid.uuid4()}.dot") + + graph(output=path) + + with open(path, "r") as f: + file_contents = f.read() + + self._verify_contents(file_contents) + os.unlink(path) diff --git a/tests/apps/shell/commands/test_run.py b/tests/apps/shell/commands/test_run.py index f2ec08a68..dd5122b41 100644 --- a/tests/apps/shell/commands/test_run.py +++ b/tests/apps/shell/commands/test_run.py @@ -17,15 +17,19 @@ def test_run(self, print_: MagicMock, start_ipython_shell: MagicMock): print_.mock_calls, [ call("-------"), - call("Importing example_app tables:"), + call("Importing music tables:"), call("- Band"), call("- Concert"), + call("- Instrument"), call("- Manager"), call("- Poster"), call("- RecordingStudio"), call("- Shirt"), call("- Ticket"), call("- Venue"), + call("Importing mega tables:"), + call("- MegaTable"), + call("- SmallTable"), call("-------"), ], ) diff --git a/tests/apps/sql_shell/commands/test_run.py b/tests/apps/sql_shell/commands/test_run.py index be4790ced..8d0c5689c 100644 --- a/tests/apps/sql_shell/commands/test_run.py +++ b/tests/apps/sql_shell/commands/test_run.py @@ -2,13 +2,37 @@ from unittest.mock import MagicMock, patch from piccolo.apps.sql_shell.commands.run import run +from tests.base import postgres_only, sqlite_only class TestRun(TestCase): + @postgres_only @patch("piccolo.apps.sql_shell.commands.run.subprocess") - def test_run(self, subprocess: MagicMock): + def test_psql(self, subprocess: MagicMock): """ - A simple test to make sure it executes without raising any exceptions. + Make sure psql was called correctly. """ run() self.assertTrue(subprocess.run.called) + + assert subprocess.run.call_args.args[0] == [ + "psql", + "-U", + "postgres", + "-h", + "localhost", + "-p", + "5432", + "piccolo", + ] + + @sqlite_only + @patch("piccolo.apps.sql_shell.commands.run.subprocess") + def test_sqlite3(self, subprocess: MagicMock): + """ + Make sure sqlite3 was called correctly. + """ + run() + self.assertTrue(subprocess.run.called) + + assert subprocess.run.call_args.args[0] == ["sqlite3", "test.sqlite"] diff --git a/tests/apps/tester/__init__.py b/tests/apps/tester/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/apps/tester/commands/test_run.py b/tests/apps/tester/commands/test_run.py new file mode 100644 index 000000000..d751e3ee6 --- /dev/null +++ b/tests/apps/tester/commands/test_run.py @@ -0,0 +1,75 @@ +import os +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from piccolo.apps.tester.commands.run import run, set_env_var + + +class TestSetEnvVar(TestCase): + def test_no_existing_value(self): + """ + Make sure the environment variable is set correctly, when there is + no existing value. + """ + var_name = "PICCOLO_TEST_1" + + # Make sure it definitely doesn't exist already + if os.environ.get(var_name) is not None: + del os.environ[var_name] + + new_value = "hello world" + + with set_env_var(var_name=var_name, temp_value=new_value): + self.assertEqual(os.environ.get(var_name), new_value) + + self.assertEqual(os.environ.get(var_name), None) + + def test_existing_value(self): + """ + Make sure the environment variable is set correctly, when there is + an existing value. 
+ """ + var_name = "PICCOLO_TEST_2" + initial_value = "hello" + new_value = "goodbye" + + os.environ[var_name] = initial_value + + with set_env_var(var_name=var_name, temp_value=new_value): + self.assertEqual(os.environ.get(var_name), new_value) + + self.assertEqual(os.environ.get(var_name), initial_value) + + def test_raise_exception(self): + """ + Make sure the environment variable is still reset, even if an exception + is raised within the context manager body. + """ + var_name = "PICCOLO_TEST_3" + initial_value = "hello" + new_value = "goodbye" + + os.environ[var_name] = initial_value + + class FakeException(Exception): + pass + + try: + with set_env_var(var_name=var_name, temp_value=new_value): + self.assertEqual(os.environ.get(var_name), new_value) + raise FakeException("Something went wrong ...") + except FakeException: + pass + + self.assertEqual(os.environ.get(var_name), initial_value) + + +class TestRun(TestCase): + @patch("piccolo.apps.tester.commands.run.run_pytest") + @patch("piccolo.apps.tester.commands.run.refresh_db") + def test_success(self, refresh_db: MagicMock, pytest: MagicMock): + with self.assertRaises(SystemExit): + run(pytest_args="-s foo", piccolo_conf="my_piccolo_conf") + + pytest.assert_called_once_with(["-s", "foo"]) + refresh_db.assert_called_once() diff --git a/tests/apps/user/commands/test_change_permissions.py b/tests/apps/user/commands/test_change_permissions.py new file mode 100644 index 000000000..6d794a852 --- /dev/null +++ b/tests/apps/user/commands/test_change_permissions.py @@ -0,0 +1,66 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from piccolo.apps.user.commands.change_permissions import ( + Level, + change_permissions, +) +from piccolo.apps.user.tables import BaseUser +from piccolo.utils.sync import run_sync + + +class TestChangePassword(TestCase): + def setUp(self): + BaseUser.create_table(if_not_exists=True).run_sync() + + BaseUser( + username="bob", + password="bob123", + first_name="Bob", + last_name="Jones", + email="bob@gmail.com", + active=False, + admin=False, + superuser=False, + ).save().run_sync() + + def tearDown(self): + BaseUser.alter().drop_table().run_sync() + + @patch("piccolo.apps.user.commands.change_permissions.colored_string") + def test_user_doesnt_exist(self, colored_string: MagicMock): + run_sync(change_permissions(username="sally")) + colored_string.assert_called_once_with( + "User sally doesn't exist!", level=Level.medium + ) + + def test_admin(self): + run_sync(change_permissions(username="bob", admin=True)) + self.assertTrue( + BaseUser.exists() + .where(BaseUser.username == "bob", BaseUser.admin.eq(True)) + .run_sync() + ) + + def test_active(self): + run_sync(change_permissions(username="bob", active=True)) + self.assertTrue( + BaseUser.exists() + .where(BaseUser.username == "bob", BaseUser.active.eq(True)) + .run_sync() + ) + + def test_superuser(self): + run_sync(change_permissions(username="bob", superuser=True)) + self.assertTrue( + BaseUser.exists() + .where(BaseUser.username == "bob", BaseUser.superuser.eq(True)) + .run_sync() + ) + + @patch("piccolo.apps.user.commands.change_permissions.colored_string") + def test_no_params(self, colored_string): + run_sync(change_permissions(username="bob")) + colored_string.assert_called_once_with( + "No changes detected", level=Level.medium + ) diff --git a/tests/apps/user/commands/test_create.py b/tests/apps/user/commands/test_create.py index 5f78dab89..5c85816fd 100644 --- a/tests/apps/user/commands/test_create.py +++ 
@@ -49,8 +49,31 @@ def test_create(self, *args, **kwargs):
                 (BaseUser.admin == True)  # noqa: E712
                 & (BaseUser.username == "bob123")
                 & (BaseUser.email == "bob@test.com")
-                & (BaseUser.superuser == True)
-                & (BaseUser.active == True)
+                & (BaseUser.superuser.eq(True))
+                & (BaseUser.active.eq(True))
             )
             .run_sync()
         )
+
+    def test_create_with_arguments(self, *args, **kwargs):
+        arguments = {
+            "username": "bob123",
+            "email": "bob@test.com",
+            "password": "password123",
+            "is_admin": True,
+            "is_superuser": True,
+            "is_active": True,
+        }
+        create(**arguments)
+
+        self.assertTrue(
+            BaseUser.exists()
+            .where(
+                (BaseUser.admin == True)  # noqa: E712
+                & (BaseUser.username == "bob123")
+                & (BaseUser.email == "bob@test.com")
+                & (BaseUser.superuser.eq(True))
+                & (BaseUser.active.eq(True))
+            )
+            .run_sync()
+        )
diff --git a/tests/apps/user/commands/test_list.py b/tests/apps/user/commands/test_list.py
new file mode 100644
index 000000000..2b35a4f65
--- /dev/null
+++ b/tests/apps/user/commands/test_list.py
@@ -0,0 +1,97 @@
+from unittest import TestCase
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from piccolo.apps.user.commands.list import list_users
+from piccolo.apps.user.tables import BaseUser
+from piccolo.utils.sync import run_sync
+
+
+class TestList(TestCase):
+    def setUp(self):
+        BaseUser.create_table(if_not_exists=True).run_sync()
+        self.username = "test_user"
+        self.password = "abc123XYZ"
+        self.user = BaseUser.create_user_sync(
+            username=self.username, password=self.password
+        )
+
+    def tearDown(self):
+        BaseUser.alter().drop_table().run_sync()
+
+    @patch("piccolo.utils.printing.print")
+    def test_list(self, print_mock: MagicMock):
+        """
+        Make sure the user information is listed, excluding the password.
+        """
+        run_sync(list_users())
+
+        output = "\n".join(i.args[0] for i in print_mock.call_args_list)
+
+        assert self.username in output
+        assert self.password not in output
+        assert self.user.password not in output
+
+
+class TestLimit(TestCase):
+    def test_non_positive(self):
+        """
+        Make sure non-positive `limit` values are rejected.
+        """
+        for value in (0, -1):
+            with self.assertRaises(ValueError):
+                run_sync(list_users(limit=value))
+
+
+class TestPage(TestCase):
+    def test_non_positive(self):
+        """
+        Make sure non-positive `page` values are rejected.
+        """
+        for value in (0, -1):
+            with self.assertRaises(ValueError):
+                run_sync(list_users(page=value))
+
+
+class TestOrder(TestCase):
+    @patch("piccolo.apps.user.commands.list.get_users")
+    def test_order(self, get_users: AsyncMock):
+        """
+        Make sure valid column names are accepted.
+        """
+        get_users.return_value = []
+        run_sync(list_users(order_by="email"))
+
+        self.assertDictEqual(
+            get_users.call_args.kwargs,
+            {
+                "order_by": BaseUser.email,
+                "ascending": True,
+                "limit": 20,
+                "page": 1,
+            },
+        )
+
+    @patch("piccolo.apps.user.commands.list.get_users")
+    def test_descending(self, get_users: AsyncMock):
+        """
+        Make sure a column name prefixed with '-' works.
+        """
+        get_users.return_value = []
+        run_sync(list_users(order_by="-email"))
+
+        self.assertDictEqual(
+            get_users.call_args.kwargs,
+            {
+                "order_by": BaseUser.email,
+                "ascending": False,
+                "limit": 20,
+                "page": 1,
+            },
+        )
+
+    def test_unrecognised_column(self):
+        """
+        Make sure invalid column names are rejected.
+ """ + with self.assertRaises(ValueError): + run_sync(list_users(order_by="abc123")) diff --git a/tests/apps/user/test_tables.py b/tests/apps/user/test_tables.py index 4e493f229..123bb5a66 100644 --- a/tests/apps/user/test_tables.py +++ b/tests/apps/user/test_tables.py @@ -1,5 +1,6 @@ -import asyncio +import secrets from unittest import TestCase +from unittest.mock import MagicMock, call, patch from piccolo.apps.user.tables import BaseUser @@ -23,9 +24,23 @@ def test_create_user_table(self): self.assertFalse(exception) -class TestHashPassword(TestCase): - def test_hash_password(self): - pass +class TestInstantiateUser(TestCase): + def setUp(self): + BaseUser.create_table().run_sync() + + def tearDown(self): + BaseUser.alter().drop_table().run_sync() + + def test_valid_credentials(self): + BaseUser(username="bob", password="abc123%£1pscl") + + def test_malicious_password(self): + malicious_password = secrets.token_urlsafe(1000) + with self.assertRaises(ValueError) as manager: + BaseUser(username="bob", password=malicious_password) + self.assertEqual( + manager.exception.__str__(), "The password is too long." + ) class TestLogin(TestCase): @@ -35,22 +50,41 @@ def setUp(self): def tearDown(self): BaseUser.alter().drop_table().run_sync() - def test_login(self): + @patch("piccolo.apps.user.tables.logger") + def test_login(self, logger: MagicMock): username = "bob" password = "Bob123$$$" email = "bob@bob.com" user = BaseUser(username=username, password=password, email=email) + user.save().run_sync() - save_query = user.save() + # Test correct password + authenticated = BaseUser.login_sync(username, password) + self.assertTrue(authenticated == user.id) - save_query.run_sync() + # Test incorrect password + authenticated = BaseUser.login_sync(username, "blablabla") + self.assertTrue(authenticated is None) - authenticated = asyncio.run(BaseUser.login(username, password)) - self.assertTrue(authenticated is not None) + # Test ultra long password + malicious_password = secrets.token_urlsafe(1000) + authenticated = BaseUser.login_sync(username, malicious_password) + self.assertTrue(authenticated is None) + self.assertEqual( + logger.method_calls, + [call.warning("Excessively long password provided.")], + ) - authenticated = asyncio.run(BaseUser.login(username, "blablabla")) - self.assertTrue(not authenticated) + # Test ulta long username + logger.reset_mock() + malicious_username = secrets.token_urlsafe(1000) + authenticated = BaseUser.login_sync(malicious_username, password) + self.assertTrue(authenticated is None) + self.assertEqual( + logger.method_calls, + [call.warning("Excessively long username provided.")], + ) def test_update_password(self): username = "bob" @@ -63,7 +97,211 @@ def test_update_password(self): authenticated = BaseUser.login_sync(username, password) self.assertTrue(authenticated is not None) + # Test success new_password = "XXX111" BaseUser.update_password_sync(username, new_password) authenticated = BaseUser.login_sync(username, new_password) self.assertTrue(authenticated is not None) + + # Test ultra long password + malicious_password = secrets.token_urlsafe(1000) + with self.assertRaises(ValueError) as manager: + BaseUser.update_password_sync(username, malicious_password) + self.assertEqual( + manager.exception.__str__(), + f"The password is too long. 
(max {BaseUser._max_password_length})", + ) + + # Test short passwords + short_password = "abc" + with self.assertRaises(ValueError) as manager: + BaseUser.update_password_sync(username, short_password) + self.assertEqual( + manager.exception.__str__(), + ( + "The password is too short. (min " + f"{BaseUser._min_password_length})" + ), + ) + + # Test no password + empty_password = "" + with self.assertRaises(ValueError) as manager: + BaseUser.update_password_sync(username, empty_password) + self.assertEqual( + manager.exception.__str__(), + "A password must be provided.", + ) + + # Test hashed password + hashed_password = "pbkdf2_sha256$abc123" + with self.assertRaises(ValueError) as manager: + BaseUser.update_password_sync(username, hashed_password) + self.assertEqual( + manager.exception.__str__(), + "Do not pass a hashed password.", + ) + + +class TestCreateUserFromFixture(TestCase): + def setUp(self): + BaseUser.create_table().run_sync() + + def tearDown(self): + BaseUser.alter().drop_table().run_sync() + + def test_create_user_from_fixture(self): + the_data = { + "id": 2, + "username": "", + "password": "pbkdf2_sha256$10000$19ed2c0d6cbe0868a70be6" + "446b93ed5b$c862974665ccc25b334ed42fa7e96a41" + "04d5ddff0c2e56e0e5b1d0efc67e9d03", + "first_name": "", + "last_name": "", + "email": "", + "active": False, + "admin": False, + "superuser": False, + "last_login": None, + } + user = BaseUser.from_dict(the_data) + self.assertIsInstance(user, BaseUser) + self.assertEqual(user.password, the_data["password"]) + + +class TestCreateUser(TestCase): + def setUp(self): + BaseUser.create_table().run_sync() + + def tearDown(self): + BaseUser.alter().drop_table().run_sync() + + def test_success(self): + user = BaseUser.create_user_sync(username="bob", password="abc123") + self.assertTrue(isinstance(user, BaseUser)) + self.assertEqual( + BaseUser.login_sync(username="bob", password="abc123"), user.id + ) + + @patch("piccolo.apps.user.tables.logger") + def test_hashed_password_error(self, logger: MagicMock): + with self.assertRaises(ValueError) as manager: + BaseUser.create_user_sync( + username="bob", password="pbkdf2_sha256$10000" + ) + + self.assertEqual( + manager.exception.__str__(), "Do not pass a hashed password." + ) + self.assertEqual( + logger.method_calls, + [ + call.warning( + "Tried to create a user with an already hashed password." + ) + ], + ) + + def test_short_password_error(self): + with self.assertRaises(ValueError) as manager: + BaseUser.create_user_sync(username="bob", password="abc") + + self.assertEqual( + manager.exception.__str__(), + ( + "The password is too short. (min " + f"{BaseUser._min_password_length})" + ), + ) + + def test_long_password_error(self): + with self.assertRaises(ValueError) as manager: + BaseUser.create_user_sync( + username="bob", + password="x" * (BaseUser._max_password_length + 1), + ) + + self.assertEqual( + manager.exception.__str__(), + f"The password is too long. (max {BaseUser._max_password_length})", + ) + + def test_no_username_error(self): + with self.assertRaises(ValueError) as manager: + BaseUser.create_user_sync( + username=None, # type: ignore + password="abc123", + ) + + self.assertEqual( + manager.exception.__str__(), "A username must be provided." + ) + + def test_no_password_error(self): + with self.assertRaises(ValueError) as manager: + BaseUser.create_user_sync( + username="bob", + password=None, # type: ignore + ) + + self.assertEqual( + manager.exception.__str__(), "A password must be provided." 
+ ) + + +class TestAutoHashingUpdate(TestCase): + """ + Make sure that users with passwords which were hashed in earlier Piccolo + versions are automatically re-hashed, meeting current best practices with + the number of hashing iterations. + """ + + def setUp(self): + BaseUser.create_table().run_sync() + + def tearDown(self): + BaseUser.alter().drop_table().run_sync() + + def test_hash_update(self): + # Create a user + username = "bob" + password = "abc123" + user = BaseUser.create_user_sync(username=username, password=password) + + # Update their password, so it uses less than the recommended number + # of hashing iterations. + BaseUser.update( + { + BaseUser.password: BaseUser.hash_password( + password=password, + iterations=int(BaseUser._pbkdf2_iteration_count / 2), + ) + } + ).where(BaseUser.id == user.id).run_sync() + + # Login the user - Piccolo should detect their password needs rehashing + # and update it. + self.assertIsNotNone( + BaseUser.login_sync(username=username, password=password) + ) + + user_data = ( + BaseUser.select(BaseUser.password) + .where(BaseUser.id == user.id) + .first() + .run_sync() + ) + assert user_data is not None + hashed_password = user_data["password"] + + algorithm, iterations_, salt, hashed = BaseUser.split_stored_password( + hashed_password + ) + + self.assertEqual(int(iterations_), BaseUser._pbkdf2_iteration_count) + + # Make sure subsequent logins work as expected + self.assertIsNotNone( + BaseUser.login_sync(username=username, password=password) + ) diff --git a/tests/base.py b/tests/base.py index bfe566d24..4651a4a04 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,43 +1,161 @@ from __future__ import annotations -import typing as t +import asyncio +import sys +from typing import Optional from unittest import TestCase from unittest.mock import MagicMock import pytest +from piccolo.apps.schema.commands.generate import RowMeta +from piccolo.engine.cockroach import CockroachEngine from piccolo.engine.finder import engine_finder from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine -from piccolo.table import Table, create_table_class +from piccolo.table import ( + Table, + create_db_tables_sync, + create_table_class, + drop_db_tables_sync, +) +from piccolo.utils.sync import run_sync ENGINE = engine_finder() +def engine_version_lt(version: float) -> bool: + return ENGINE is not None and run_sync(ENGINE.get_version()) < version + + +def is_running_postgres() -> bool: + return type(ENGINE) is PostgresEngine + + +def is_running_sqlite() -> bool: + return type(ENGINE) is SQLiteEngine + + +def is_running_cockroach() -> bool: + return type(ENGINE) is CockroachEngine + + postgres_only = pytest.mark.skipif( - not isinstance(ENGINE, PostgresEngine), reason="Only running for Postgres" + not is_running_postgres(), reason="Only running for Postgres" ) - sqlite_only = pytest.mark.skipif( - not isinstance(ENGINE, SQLiteEngine), reason="Only running for SQLite" + not is_running_sqlite(), reason="Only running for SQLite" +) + +cockroach_only = pytest.mark.skipif( + not is_running_cockroach(), reason="Only running for Cockroach" +) + +unix_only = pytest.mark.skipif( + sys.platform.startswith("win"), reason="Only running on a Unix system" ) -def set_mock_return_value(magic_mock: MagicMock, return_value: t.Any): +def engines_only(*engine_names: str): """ - Python 3.8 has good support for mocking coroutines. For older versions, - we must set the return value to be an awaitable explicitly. + Test decorator. 
Choose which engines can run a test.
+
+    For example::
+
+        @engines_only('cockroach', 'postgres')
+        def test_unknown_column_type(...):
+            self.assertTrue(...)
+
+    """
+    if ENGINE:
+        current_engine_name = ENGINE.engine_type
+        if current_engine_name not in engine_names:
+
+            def wrapper(func):
+                return pytest.mark.skip(
+                    f"Not running for {current_engine_name}"
+                )(func)
+
+            return wrapper
+        else:
+
+            def wrapper(func):
+                return func
+
+            return wrapper
+    else:
+        raise ValueError("Engine not found")
+
+
+def engines_skip(*engine_names: str):
+    """
+    Test decorator. Choose which engines should skip a test.
+
+    For example::
+
+        @engines_skip('cockroach', 'postgres')
+        def test_unknown_column_type(...):
+            self.assertTrue(...)
+
+    """
+    if ENGINE:
+        current_engine_name = ENGINE.engine_type
+        if current_engine_name in engine_names:
+
+            def wrapper(func):
+                return pytest.mark.skip(
+                    f"Not yet available for {current_engine_name}"
+                )(func)
+
+            return wrapper
+        else:
+
+            def wrapper(func):
+                return func
+
+            return wrapper
+    else:
+        raise ValueError("Engine not found")
+
+
+def engine_is(*engine_names: str):
+    """
+    Assert branching. Choose which engines run an assert. If the branching
+    becomes too complex, make a new test with ``@engines_only()`` or
+    ``@engines_skip()`` instead.
+
+    For example::
+
+        def test_unknown_column_type(...):
+            if engine_is('cockroach', 'sqlite'):
+                self.assertTrue(...)
+    """
+    if ENGINE:
+        current_engine_name = ENGINE.engine_type
+        if current_engine_name not in engine_names:
+            return False
+        else:
+            return True
+    else:
+        raise ValueError("Engine not found")
+
+
+class AsyncMock(MagicMock):
+    """
+    Async ``MagicMock`` for Python 3.7+.
+
+    This is a workaround for the fact that ``MagicMock`` is not
+    async-compatible in Python 3.7.
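+
+    A sketch of the intended usage (``fetch`` is just an illustrative name)::
+
+        fetch = AsyncMock(return_value=[{"id": 1}])
+        assert asyncio.run(fetch()) == [{"id": 1}]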
+ """ - async def coroutine(*args, **kwargs): - return return_value + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - magic_mock.return_value = coroutine() + # this makes asyncio.iscoroutinefunction(AsyncMock()) return True + self._is_coroutine = asyncio.coroutines._is_coroutine + + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) class DBTestCase(TestCase): @@ -51,7 +169,7 @@ def run_sync(self, query): return _Table.raw(query).run_sync() def table_exists(self, tablename: str) -> bool: - _Table: t.Type[Table] = create_table_class( + _Table: type[Table] = create_table_class( class_name=tablename.upper(), class_kwargs={"tablename": tablename} ) return _Table.table_exists().run_sync() @@ -61,18 +179,25 @@ def table_exists(self, tablename: str) -> bool: # Postgres specific utils def get_postgres_column_definition( - self, tablename: str, column_name: str - ) -> t.Dict[str, t.Any]: + self, tablename: str, column_name: str, schema: str = "public" + ) -> RowMeta: query = """ - SELECT * FROM information_schema.columns + SELECT {columns} FROM information_schema.columns WHERE table_name = '{tablename}' AND table_catalog = 'piccolo' + AND table_schema = '{schema}' AND column_name = '{column_name}' """.format( - tablename=tablename, column_name=column_name + columns=RowMeta.get_column_name_str(), + tablename=tablename, + schema=schema, + column_name=column_name, ) response = self.run_sync(query) - return response[0] + if len(response) > 0: + return RowMeta(**response[0]) + else: + raise ValueError("No such column") def get_postgres_column_type( self, tablename: str, column_name: str @@ -82,7 +207,7 @@ def get_postgres_column_type( """ return self.get_postgres_column_definition( tablename=tablename, column_name=column_name - )["data_type"].upper() + ).data_type.upper() def get_postgres_is_nullable(self, tablename, column_name: str) -> bool: """ @@ -91,22 +216,26 @@ def get_postgres_is_nullable(self, tablename, column_name: str) -> bool: return ( self.get_postgres_column_definition( tablename=tablename, column_name=column_name - )["is_nullable"].upper() + ).is_nullable.upper() == "YES" ) - def get_postgres_varchar_length(self, tablename, column_name: str) -> int: + def get_postgres_varchar_length( + self, tablename, column_name: str + ) -> Optional[int]: """ Fetches whether the column is defined as nullable, from the database. 
""" return self.get_postgres_column_definition( tablename=tablename, column_name=column_name - )["character_maximum_length"] + ).character_maximum_length ########################################################################### def create_tables(self): - if ENGINE.engine_type == "postgres": + assert ENGINE is not None + + if ENGINE.engine_type in ("postgres", "cockroach"): self.run_sync( """ CREATE TABLE manager ( @@ -186,60 +315,120 @@ def create_tables(self): raise Exception("Unrecognised engine") def insert_row(self): - self.run_sync( - """ - INSERT INTO manager ( - name - ) VALUES ( - 'Guido' - );""" - ) - self.run_sync( - """ - INSERT INTO band ( - name, - manager, - popularity - ) VALUES ( - 'Pythonistas', - 1, - 1000 - );""" - ) + assert ENGINE is not None + + if ENGINE.engine_type == "cockroach": + id = self.run_sync( + """ + INSERT INTO manager ( + name + ) VALUES ( + 'Guido' + ) RETURNING id;""" + ) + self.run_sync( + f""" + INSERT INTO band ( + name, + manager, + popularity + ) VALUES ( + 'Pythonistas', + {id[0]["id"]}, + 1000 + );""" + ) + else: + self.run_sync( + """ + INSERT INTO manager ( + name + ) VALUES ( + 'Guido' + );""" + ) + self.run_sync( + """ + INSERT INTO band ( + name, + manager, + popularity + ) VALUES ( + 'Pythonistas', + 1, + 1000 + );""" + ) def insert_rows(self): - self.run_sync( - """ - INSERT INTO manager ( - name - ) VALUES ( - 'Guido' - ),( - 'Graydon' - ),( - 'Mads' - );""" - ) - self.run_sync( - """ - INSERT INTO band ( - name, - manager, - popularity - ) VALUES ( - 'Pythonistas', - 1, - 1000 - ),( - 'Rustaceans', - 2, - 2000 - ),( - 'CSharps', - 3, - 10 - );""" - ) + assert ENGINE is not None + + if ENGINE.engine_type == "cockroach": + id = self.run_sync( + """ + INSERT INTO manager ( + name + ) VALUES ( + 'Guido' + ),( + 'Graydon' + ),( + 'Mads' + ) RETURNING id;""" + ) + self.run_sync( + f""" + INSERT INTO band ( + name, + manager, + popularity + ) VALUES ( + 'Pythonistas', + {id[0]["id"]}, + 1000 + ),( + 'Rustaceans', + {id[1]["id"]}, + 2000 + ),( + 'CSharps', + {id[2]["id"]}, + 10 + );""" + ) + else: + self.run_sync( + """ + INSERT INTO manager ( + name + ) VALUES ( + 'Guido' + ),( + 'Graydon' + ),( + 'Mads' + );""" + ) + self.run_sync( + """ + INSERT INTO band ( + name, + manager, + popularity + ) VALUES ( + 'Pythonistas', + 1, + 1000 + ),( + 'Rustaceans', + 2, + 2000 + ),( + 'CSharps', + 3, + 10 + );""" + ) def insert_many_rows(self, row_count=10000): """ @@ -250,7 +439,9 @@ def insert_many_rows(self, row_count=10000): self.run_sync(f"INSERT INTO manager (name) VALUES {values_string};") def drop_tables(self): - if ENGINE.engine_type == "postgres": + assert ENGINE is not None + + if ENGINE.engine_type in ("postgres", "cockroach"): self.run_sync("DROP TABLE IF EXISTS band CASCADE;") self.run_sync("DROP TABLE IF EXISTS manager CASCADE;") self.run_sync("DROP TABLE IF EXISTS ticket CASCADE;") @@ -268,3 +459,17 @@ def setUp(self): def tearDown(self): self.drop_tables() + + +class TableTest(TestCase): + """ + Used for tests where we need to create Piccolo tables. 
+ """ + + tables: list[type[Table]] + + def setUp(self) -> None: + create_db_tables_sync(*self.tables) + + def tearDown(self) -> None: + drop_db_tables_sync(*self.tables) diff --git a/tests/cockroach_conf.py b/tests/cockroach_conf.py new file mode 100644 index 000000000..11b9bf651 --- /dev/null +++ b/tests/cockroach_conf.py @@ -0,0 +1,22 @@ +import os + +from piccolo.conf.apps import AppRegistry +from piccolo.engine.cockroach import CockroachEngine + +DB = CockroachEngine( + config={ + "host": os.environ.get("PG_HOST", "localhost"), + "port": os.environ.get("PG_PORT", "26257"), + "user": os.environ.get("PG_USER", "root"), + "password": os.environ.get("PG_PASSWORD", ""), + "database": os.environ.get("PG_DATABASE", "piccolo"), + } +) + + +APP_REGISTRY = AppRegistry( + apps=[ + "tests.example_apps.music.piccolo_app", + "tests.example_apps.mega.piccolo_app", + ] +) diff --git a/tests/columns/foreign_key/test_all_columns.py b/tests/columns/foreign_key/test_all_columns.py new file mode 100644 index 000000000..0d6828ddf --- /dev/null +++ b/tests/columns/foreign_key/test_all_columns.py @@ -0,0 +1,51 @@ +from unittest import TestCase + +from tests.example_apps.music.tables import Band, Concert + + +class TestAllColumns(TestCase): + def test_all_columns(self): + """ + Make sure you can retrieve all columns from a related table, without + explicitly specifying them. + """ + all_columns = Band.manager.all_columns() + self.assertEqual(all_columns, [Band.manager.id, Band.manager.name]) + + # Make sure the call chains are also correct. + self.assertEqual( + all_columns[0]._meta.call_chain, Band.manager.id._meta.call_chain + ) + self.assertEqual( + all_columns[1]._meta.call_chain, Band.manager.name._meta.call_chain + ) + + def test_all_columns_deep(self): + """ + Make sure ``all_columns`` works when the joins are several layers deep. + """ + all_columns = Concert.band_1._.manager.all_columns() + self.assertEqual(all_columns, [Band.manager._.id, Band.manager._.name]) + + # Make sure the call chains are also correct. + self.assertEqual( + all_columns[0]._meta.call_chain, + Concert.band_1._.manager._.id._meta.call_chain, + ) + self.assertEqual( + all_columns[1]._meta.call_chain, + Concert.band_1._.manager._.name._meta.call_chain, + ) + + def test_all_columns_exclude(self): + """ + Make sure you can exclude some columns. + """ + self.assertEqual( + Band.manager.all_columns(exclude=["id"]), [Band.manager.name] + ) + + self.assertEqual( + Band.manager.all_columns(exclude=[Band.manager.id]), + [Band.manager.name], + ) diff --git a/tests/columns/foreign_key/test_all_related.py b/tests/columns/foreign_key/test_all_related.py new file mode 100644 index 000000000..94ebf7dc2 --- /dev/null +++ b/tests/columns/foreign_key/test_all_related.py @@ -0,0 +1,62 @@ +from unittest import TestCase + +from tests.example_apps.music.tables import Ticket + + +class TestAllRelated(TestCase): + def test_all_related(self): + """ + Make sure you can retrieve all foreign keys from a related table, + without explicitly specifying them. + """ + all_related = Ticket.concert.all_related() + + self.assertEqual( + all_related, + [ + Ticket.concert.band_1, + Ticket.concert.band_2, + Ticket.concert.venue, + ], + ) + + # Make sure the call chains are also correct. 
+ self.assertEqual( + all_related[0]._meta.call_chain, + Ticket.concert.band_1._meta.call_chain, + ) + self.assertEqual( + all_related[1]._meta.call_chain, + Ticket.concert.band_2._meta.call_chain, + ) + self.assertEqual( + all_related[2]._meta.call_chain, + Ticket.concert.venue._meta.call_chain, + ) + + def test_all_related_deep(self): + """ + Make sure ``all_related`` works when the joins are several layers deep. + """ + all_related = Ticket.concert._.band_1.all_related() + self.assertEqual(all_related, [Ticket.concert._.band_1._.manager]) + + # Make sure the call chains are also correct. + self.assertEqual( + all_related[0]._meta.call_chain, + Ticket.concert._.band_1._.manager._meta.call_chain, + ) + + def test_all_related_exclude(self): + """ + Make sure you can exclude some columns. + """ + self.assertEqual( + Ticket.concert.all_related(exclude=["venue"]), + [Ticket.concert.band_1, Ticket.concert.band_2], + ) + + self.assertEqual( + Ticket.concert.all_related(exclude=[Ticket.concert._.venue]), + [Ticket.concert.band_1, Ticket.concert.band_2], + ) diff --git a/tests/columns/foreign_key/test_attribute_access.py b/tests/columns/foreign_key/test_attribute_access.py new file mode 100644 index 000000000..597b33bd6 --- /dev/null +++ b/tests/columns/foreign_key/test_attribute_access.py @@ -0,0 +1,65 @@ +import time +from unittest import TestCase + +from piccolo.columns import Column, ForeignKey, LazyTableReference, Varchar +from piccolo.table import Table + + +class Manager(Table): + name = Varchar() + manager: ForeignKey["Manager"] = ForeignKey("self") + + +class BandA(Table): + manager = ForeignKey(references=Manager) + + +class BandB(Table): + manager: ForeignKey["Manager"] = ForeignKey(references="Manager") + + +class BandC(Table): + manager: ForeignKey["Manager"] = ForeignKey( + references=LazyTableReference( + table_class_name="Manager", + module_path=__name__, + ) + ) + + +class BandD(Table): + manager: ForeignKey["Manager"] = ForeignKey( + references=f"{__name__}.Manager" + ) + + +class TestAttributeAccess(TestCase): + def test_attribute_access(self): + """ + Make sure that attribute access still works correctly with lazy + references. + """ + for band_table in (BandA, BandB, BandC, BandD): + self.assertIsInstance(band_table.manager.name, Varchar) + + def test_recursion_limit(self) -> None: + """ + When a table has a ForeignKey to itself, an Exception should be raised + if the call chain is too large. + """ + # Should be fine: + column: Column = Manager.manager.name + self.assertTrue(len(column._meta.call_chain), 1) + self.assertIsInstance(column, Varchar) + + with self.assertRaises(Exception): + Manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.name # type: ignore # noqa: E501 + + def test_recursion_time(self): + """ + Make sure that a really large call chain doesn't take too long. 
+ """ + start = time.time() + Manager.manager._.manager._.manager._.manager._.manager._.manager._.name # noqa: E501 + end = time.time() + self.assertLess(end - start, 1.0) diff --git a/tests/columns/foreign_key/test_column_type.py b/tests/columns/foreign_key/test_column_type.py new file mode 100644 index 000000000..68f3d6b00 --- /dev/null +++ b/tests/columns/foreign_key/test_column_type.py @@ -0,0 +1,69 @@ +from unittest import TestCase + +from piccolo.columns import ( + UUID, + BigInt, + BigSerial, + ForeignKey, + Integer, + Serial, + Varchar, +) +from piccolo.table import Table + + +class TestColumnType(TestCase): + """ + The `column_type` of the `ForeignKey` should depend on the `PrimaryKey` of + the referenced table. + """ + + def test_serial(self): + class Manager(Table): + id = Serial(primary_key=True) + + class Band(Table): + manager = ForeignKey(Manager) + + self.assertEqual( + Band.manager.column_type, + Integer().column_type, + ) + + def test_bigserial(self): + class Manager(Table): + id = BigSerial(primary_key=True) + + class Band(Table): + manager = ForeignKey(Manager) + + self.assertEqual( + Band.manager.column_type, + BigInt()._get_column_type( + engine_type=Band.manager._meta.engine_type + ), + ) + + def test_uuid(self): + class Manager(Table): + id = UUID(primary_key=True) + + class Band(Table): + manager = ForeignKey(Manager) + + self.assertEqual( + Band.manager.column_type, + Manager.id.column_type, + ) + + def test_varchar(self): + class Manager(Table): + id = Varchar(primary_key=True) + + class Band(Table): + manager = ForeignKey(Manager) + + self.assertEqual( + Band.manager.column_type, + Manager.id.column_type, + ) diff --git a/tests/columns/foreign_key/test_foreign_key_meta.py b/tests/columns/foreign_key/test_foreign_key_meta.py new file mode 100644 index 000000000..78560fab8 --- /dev/null +++ b/tests/columns/foreign_key/test_foreign_key_meta.py @@ -0,0 +1,36 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, Varchar +from piccolo.columns.base import OnDelete, OnUpdate +from piccolo.table import Table + + +class Manager(Table): + name = Varchar() + + +class Band(Table): + """ + Contains a ForeignKey with non-default `on_delete` and `on_update` values. + """ + + manager = ForeignKey( + references=Manager, + on_delete=OnDelete.set_null, + on_update=OnUpdate.set_null, + ) + + +class TestForeignKeyMeta(TestCase): + """ + Make sure that `ForeignKeyMeta` is setup correctly. 
+ """ + + def test_foreignkeymeta(self): + self.assertTrue( + Band.manager._foreign_key_meta.on_update == OnUpdate.set_null + ) + self.assertTrue( + Band.manager._foreign_key_meta.on_delete == OnDelete.set_null + ) + self.assertTrue(Band.manager._foreign_key_meta.references == Manager) diff --git a/tests/columns/foreign_key/test_foreign_key_references.py b/tests/columns/foreign_key/test_foreign_key_references.py new file mode 100644 index 000000000..76b5f2f39 --- /dev/null +++ b/tests/columns/foreign_key/test_foreign_key_references.py @@ -0,0 +1,28 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, Varchar +from piccolo.table import Table + + +class Manager(Table, tablename="manager_fk_references_test"): + name = Varchar() + + +class BandA(Table): + manager = ForeignKey(references=Manager) + + +class BandB(Table): + manager: ForeignKey["Manager"] = ForeignKey(references="Manager") + + +class TestReferences(TestCase): + def test_foreign_key_references(self): + """ + Make sure foreign key references are stored correctly on the table + which is the target of the ForeignKey. + """ + self.assertEqual(len(Manager._meta.foreign_key_references), 2) + + self.assertTrue(BandA.manager in Manager._meta.foreign_key_references) + self.assertTrue(BandB.manager in Manager._meta.foreign_key_references) diff --git a/tests/columns/foreign_key/test_foreign_key_self.py b/tests/columns/foreign_key/test_foreign_key_self.py new file mode 100644 index 000000000..18c35e337 --- /dev/null +++ b/tests/columns/foreign_key/test_foreign_key_self.py @@ -0,0 +1,43 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, Serial, Varchar +from piccolo.table import Table + + +class Manager(Table, tablename="manager"): + id: Serial + name = Varchar() + manager: ForeignKey["Manager"] = ForeignKey("self", null=True) + + +class TestForeignKeySelf(TestCase): + """ + Test that ForeignKey columns can be created with references to the parent + table. + """ + + def setUp(self): + Manager.create_table().run_sync() + + def tearDown(self): + Manager.alter().drop_table().run_sync() + + def test_foreign_key_self(self): + manager = Manager(name="Mr Manager") + manager.save().run_sync() + + worker = Manager(name="Mr Worker", manager=manager.id) + worker.save().run_sync() + + response = ( + Manager.select(Manager.name, Manager.manager.name) + .order_by(Manager.name) + .run_sync() + ) + self.assertEqual( + response, + [ + {"name": "Mr Manager", "manager.name": None}, + {"name": "Mr Worker", "manager.name": "Mr Manager"}, + ], + ) diff --git a/tests/columns/foreign_key/test_foreign_key_string.py b/tests/columns/foreign_key/test_foreign_key_string.py new file mode 100644 index 000000000..e37298734 --- /dev/null +++ b/tests/columns/foreign_key/test_foreign_key_string.py @@ -0,0 +1,71 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, LazyTableReference, Varchar +from piccolo.table import Table + + +class Manager(Table): + name = Varchar() + + +class BandA(Table): + manager: ForeignKey["Manager"] = ForeignKey(references="Manager") + + +class BandB(Table): + manager: ForeignKey["Manager"] = ForeignKey( + references=LazyTableReference( + table_class_name="Manager", + module_path=__name__, + ) + ) + + +class BandC(Table, tablename="band"): + manager: ForeignKey["Manager"] = ForeignKey( + references=f"{__name__}.Manager" + ) + + +class TestForeignKeyString(TestCase): + """ + Test that ForeignKey columns can be created with a `references` argument + set as a string value. 
+ """ + + def test_foreign_key_string(self): + for band_table in (BandA, BandB, BandC): + self.assertIs( + band_table.manager._foreign_key_meta.resolved_references, + Manager, + ) + + +class TestForeignKeyRelativeError(TestCase): + def test_foreign_key_relative_error(self): + """ + Make sure that a references argument which contains a relative module + isn't allowed. + """ + with self.assertRaises(ValueError) as manager: + + class BandRelative(Table, tablename="band"): + manager = ForeignKey("..example_app.tables.Manager", null=True) + + self.assertEqual( + manager.exception.__str__(), "Relative imports aren't allowed" + ) + + +class TestLazyTableReference(TestCase): + def test_lazy_reference_to_app(self): + """ + Make sure a LazyTableReference to a Table within a Piccolo app works. + """ + from tests.example_apps.music.tables import Manager + + reference = LazyTableReference( + table_class_name="Manager", app_name="music" + ) + + self.assertIs(reference.resolve(), Manager) diff --git a/tests/columns/foreign_key/test_on_delete_on_update.py b/tests/columns/foreign_key/test_on_delete_on_update.py new file mode 100644 index 000000000..c7356e171 --- /dev/null +++ b/tests/columns/foreign_key/test_on_delete_on_update.py @@ -0,0 +1,37 @@ +from piccolo.columns import ForeignKey, Varchar +from piccolo.columns.base import OnDelete, OnUpdate +from piccolo.query.constraints import get_fk_constraint_rules +from piccolo.table import Table +from piccolo.testing.test_case import AsyncTableTest +from tests.base import engines_only + + +class Manager(Table): + name = Varchar() + + +class Band(Table): + """ + Contains a ForeignKey with non-default `on_delete` and `on_update` values. + """ + + manager = ForeignKey( + references=Manager, + on_delete=OnDelete.set_null, + on_update=OnUpdate.set_null, + ) + + +@engines_only("postgres", "cockroach") +class TestOnDeleteOnUpdate(AsyncTableTest): + """ + Make sure that on_delete, and on_update are correctly applied in the + database. 
+ """ + + tables = [Manager, Band] + + async def test_on_delete_on_update(self): + constraint_rules = await get_fk_constraint_rules(Band.manager) + self.assertEqual(constraint_rules.on_delete, OnDelete.set_null) + self.assertEqual(constraint_rules.on_update, OnDelete.set_null) diff --git a/tests/columns/foreign_key/test_reverse.py b/tests/columns/foreign_key/test_reverse.py new file mode 100644 index 000000000..5bf490c09 --- /dev/null +++ b/tests/columns/foreign_key/test_reverse.py @@ -0,0 +1,56 @@ +from piccolo.columns import ForeignKey, Text, Varchar +from piccolo.table import Table +from piccolo.testing.test_case import TableTest + + +class Band(Table): + name = Varchar() + + +class FanClub(Table): + address = Text() + band = ForeignKey(Band, unique=True) + + +class Treasurer(Table): + name = Varchar() + fan_club = ForeignKey(FanClub, unique=True) + + +class TestReverse(TableTest): + tables = [Band, FanClub, Treasurer] + + def setUp(self): + super().setUp() + + band = Band({Band.name: "Pythonistas"}) + band.save().run_sync() + + fan_club = FanClub( + {FanClub.band: band, FanClub.address: "1 Flying Circus, UK"} + ) + fan_club.save().run_sync() + + treasurer = Treasurer( + {Treasurer.fan_club: fan_club, Treasurer.name: "Bob"} + ) + treasurer.save().run_sync() + + def test_reverse(self): + response = Band.select( + Band.name, + FanClub.band.reverse().address.as_alias("address"), + Treasurer.fan_club._.band.reverse().name.as_alias( + "treasurer_name" + ), + ).run_sync() + self.assertListEqual( + response, + [ + { + "name": "Pythonistas", + "address": "1 Flying Circus, UK", + "treasurer_name": "Bob", + } + ], + ) diff --git a/tests/columns/foreign_key/test_schema.py b/tests/columns/foreign_key/test_schema.py new file mode 100644 index 000000000..7e6b45c18 --- /dev/null +++ b/tests/columns/foreign_key/test_schema.py @@ -0,0 +1,103 @@ +import datetime +from unittest import TestCase + +from piccolo.columns import Date, ForeignKey, Varchar +from piccolo.schema import SchemaManager +from piccolo.table import Table, create_db_tables_sync +from tests.base import engines_only + + +class Manager(Table, schema="schema_1"): + name = Varchar(length=50) + + +class Band(Table, schema="schema_1"): + name = Varchar(length=50) + manager = ForeignKey(Manager) + + +class Concert(Table, schema="schema_1"): + start_date = Date() + band = ForeignKey(Band) + + +TABLES = [Band, Manager, Concert] + + +@engines_only("postgres", "cockroach") +class TestForeignKeyWithSchema(TestCase): + """ + Make sure that foreign keys work with Postgres schemas. + """ + + schema_manager = SchemaManager() + schema_name = "schema_1" + + def setUp(self) -> None: + self.schema_manager.create_schema( + schema_name=self.schema_name + ).run_sync() + create_db_tables_sync(*TABLES) + + def tearDown(self) -> None: + self.schema_manager.drop_schema( + schema_name=self.schema_name, if_exists=True, cascade=True + ).run_sync() + + def test_with_schema(self): + """ + Make sure that foreign keys work with schemas. + """ + manager = Manager({Manager.name: "Guido"}) + manager.save().run_sync() + + band = Band({Band.manager: manager, Band.name: "Pythonistas"}) + band.save().run_sync() + + concert = Concert( + { + Concert.band: band, + Concert.start_date: datetime.date(year=2023, month=1, day=1), + } + ) + concert.save().run_sync() + + ####################################################################### + # Test single level join. 
+ + query = Band.select( + Band.name, + Band.manager.name.as_alias("manager_name"), + ) + self.assertIn('"schema_1"."band"', query.__str__()) + self.assertIn('"schema_1"."manager"', query.__str__()) + + response = query.run_sync() + self.assertListEqual( + response, + [{"name": "Pythonistas", "manager_name": "Guido"}], + ) + + ####################################################################### + # Test two level join. + + query = Concert.select( + Concert.start_date, + Concert.band.name.as_alias("band_name"), + Concert.band._.manager._.name.as_alias("manager_name"), + ) + self.assertIn('"schema_1"."concert"', query.__str__()) + self.assertIn('"schema_1"."band"', query.__str__()) + self.assertIn('"schema_1"."manager"', query.__str__()) + + response = query.run_sync() + self.assertListEqual( + response, + [ + { + "start_date": datetime.date(2023, 1, 1), + "band_name": "Pythonistas", + "manager_name": "Guido", + } + ], + ) diff --git a/tests/columns/foreign_key/test_target_column.py b/tests/columns/foreign_key/test_target_column.py new file mode 100644 index 000000000..e9a0c4460 --- /dev/null +++ b/tests/columns/foreign_key/test_target_column.py @@ -0,0 +1,87 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, Varchar +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync + + +class Manager(Table): + name = Varchar(unique=True) + + +class Band(Table): + name = Varchar() + manager = ForeignKey(Manager, target_column="name") + + +class TestTargetColumnWithString(TestCase): + """ + Make sure we can create tables with foreign keys which don't reference + the primary key. + """ + + def setUp(self): + create_db_tables_sync(Manager, Band) + + def tearDown(self): + drop_db_tables_sync(Manager, Band) + + def test_queries(self): + manager_1 = Manager.objects().create(name="Guido").run_sync() + manager_2 = Manager.objects().create(name="Graydon").run_sync() + + Band.insert( + Band(name="Pythonistas", manager=manager_1), + Band(name="Rustaceans", manager=manager_2), + ).run_sync() + + response = Band.select(Band.name, Band.manager.name).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "manager.name": "Guido"}, + {"name": "Rustaceans", "manager.name": "Graydon"}, + ], + ) + + +############################################################################### + + +class ManagerA(Table): + name = Varchar(unique=True) + + +class BandA(Table): + name = Varchar() + manager = ForeignKey(ManagerA, target_column=ManagerA.name) + + +class TestTargetColumnWithColumnRef(TestCase): + """ + Make sure we can create tables with foreign keys which don't reference + the primary key. 
+ """ + + def setUp(self): + create_db_tables_sync(ManagerA, BandA) + + def tearDown(self): + drop_db_tables_sync(ManagerA, BandA) + + def test_queries(self): + manager_1 = ManagerA.objects().create(name="Guido").run_sync() + manager_2 = ManagerA.objects().create(name="Graydon").run_sync() + + BandA.insert( + BandA(name="Pythonistas", manager=manager_1), + BandA(name="Rustaceans", manager=manager_2), + ).run_sync() + + response = BandA.select(BandA.name, BandA.manager.name).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "manager.name": "Guido"}, + {"name": "Rustaceans", "manager.name": "Graydon"}, + ], + ) diff --git a/tests/columns/foreign_key/test_value_type.py b/tests/columns/foreign_key/test_value_type.py new file mode 100644 index 000000000..1c858d425 --- /dev/null +++ b/tests/columns/foreign_key/test_value_type.py @@ -0,0 +1,33 @@ +import uuid +from unittest import TestCase + +from piccolo.columns import UUID, ForeignKey, Varchar +from piccolo.table import Table + + +class Manager(Table): + name = Varchar() + manager: ForeignKey["Manager"] = ForeignKey("self", null=True) + + +class Band(Table): + manager = ForeignKey(references=Manager) + + +class ManagerUUID(Table): + pk = UUID(primary_key=True) + + +class BandUUID(Table): + manager = ForeignKey(references=ManagerUUID) + + +class TestValueType(TestCase): + """ + The `value_type` of the `ForeignKey` should depend on the `PrimaryKey` of + the referenced table. + """ + + def test_value_type(self): + self.assertTrue(Band.manager.value_type is int) + self.assertTrue(BandUUID.manager.value_type is uuid.UUID) diff --git a/tests/columns/m2m/__init__.py b/tests/columns/m2m/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/columns/m2m/base.py b/tests/columns/m2m/base.py new file mode 100644 index 000000000..066ebab11 --- /dev/null +++ b/tests/columns/m2m/base.py @@ -0,0 +1,466 @@ +from typing import Optional + +from piccolo.columns.column_types import ( + ForeignKey, + LazyTableReference, + Serial, + Text, + Varchar, +) +from piccolo.columns.m2m import M2M +from piccolo.engine.finder import engine_finder +from piccolo.schema import SchemaManager +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import engines_skip + +engine = engine_finder() + + +class Band(Table): + id: Serial + name = Varchar() + genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + +class Genre(Table): + id: Serial + name = Varchar() + bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + +class GenreToBand(Table): + id: Serial + band = ForeignKey(Band) + genre = ForeignKey(Genre) + reason = Text(help_text="For testing additional columns on join tables.") + + +class M2MBase: + """ + This allows us to test M2M when the tables are in different schemas + (public vs non-public). 
+ """ + + def _setUp(self, schema: Optional[str] = None): + self.schema = schema + + for table_class in (Band, Genre, GenreToBand): + table_class._meta.schema = schema + + self.all_tables = [Band, Genre, GenreToBand] + + create_db_tables_sync(*self.all_tables, if_not_exists=True) + + bands = Band.insert( + Band(name="Pythonistas"), + Band(name="Rustaceans"), + Band(name="C-Sharps"), + ).run_sync() + + genres = Genre.insert( + Genre(name="Rock"), + Genre(name="Folk"), + Genre(name="Classical"), + ).run_sync() + + GenreToBand.insert( + GenreToBand(band=bands[0]["id"], genre=genres[0]["id"]), + GenreToBand(band=bands[0]["id"], genre=genres[1]["id"]), + GenreToBand(band=bands[1]["id"], genre=genres[1]["id"]), + GenreToBand(band=bands[2]["id"], genre=genres[0]["id"]), + GenreToBand(band=bands[2]["id"], genre=genres[2]["id"]), + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*self.all_tables) + + if self.schema: + SchemaManager().drop_schema( + schema_name="schema_1", cascade=True + ).run_sync() + + def assertEqual(self, first, second, msg=None): + assert first == second + + def assertTrue(self, first, msg=None): + assert first is True + + @engines_skip("cockroach") + def test_select_name(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + response = Band.select( + Band.name, Band.genres(Genre.name, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": ["Rock", "Folk"]}, + {"name": "Rustaceans", "genres": ["Folk"]}, + {"name": "C-Sharps", "genres": ["Rock", "Classical"]}, + ], + ) + + # Now try it in reverse. + response = Genre.select( + Genre.name, Genre.bands(Band.name, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Rock", "bands": ["Pythonistas", "C-Sharps"]}, + {"name": "Folk", "bands": ["Pythonistas", "Rustaceans"]}, + {"name": "Classical", "bands": ["C-Sharps"]}, + ], + ) + + @engines_skip("cockroach") + def test_no_related(self): + """ + Make sure it still works correctly if there are no related values. 
+ """ + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + + GenreToBand.delete(force=True).run_sync() + + # Try it with a list response + response = Band.select( + Band.name, Band.genres(Genre.name, as_list=True) + ).run_sync() + + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": []}, + {"name": "Rustaceans", "genres": []}, + {"name": "C-Sharps", "genres": []}, + ], + ) + + # Also try it with a nested response + response = Band.select( + Band.name, Band.genres(Genre.id, Genre.name) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": []}, + {"name": "Rustaceans", "genres": []}, + {"name": "C-Sharps", "genres": []}, + ], + ) + + @engines_skip("cockroach") + def test_select_multiple(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + + response = Band.select( + Band.name, Band.genres(Genre.id, Genre.name) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Pythonistas", + "genres": [ + {"id": 1, "name": "Rock"}, + {"id": 2, "name": "Folk"}, + ], + }, + {"name": "Rustaceans", "genres": [{"id": 2, "name": "Folk"}]}, + { + "name": "C-Sharps", + "genres": [ + {"id": 1, "name": "Rock"}, + {"id": 3, "name": "Classical"}, + ], + }, + ], + ) + + # Now try it in reverse. + response = Genre.select( + Genre.name, Genre.bands(Band.id, Band.name) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Rock", + "bands": [ + {"id": 1, "name": "Pythonistas"}, + {"id": 3, "name": "C-Sharps"}, + ], + }, + { + "name": "Folk", + "bands": [ + {"id": 1, "name": "Pythonistas"}, + {"id": 2, "name": "Rustaceans"}, + ], + }, + { + "name": "Classical", + "bands": [{"id": 3, "name": "C-Sharps"}], + }, + ], + ) + + @engines_skip("cockroach") + def test_select_id(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + + response = Band.select( + Band.name, Band.genres(Genre.id, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": [1, 2]}, + {"name": "Rustaceans", "genres": [2]}, + {"name": "C-Sharps", "genres": [1, 3]}, + ], + ) + + # Now try it in reverse. + response = Genre.select( + Genre.name, Genre.bands(Band.id, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Rock", "bands": [1, 3]}, + {"name": "Folk", "bands": [1, 2]}, + {"name": "Classical", "bands": [3]}, + ], + ) + + @engines_skip("cockroach") + def test_select_all_columns(self): + """ + Make sure ``all_columns`` can be passed in as an argument. ``M2M`` + should flatten the arguments. Reported here: + + https://github.com/piccolo-orm/piccolo/issues/728 + + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + """ # noqa: E501 + + response = Band.select( + Band.name, Band.genres(Genre.all_columns(exclude=(Genre.id,))) + ).run_sync() + self.assertEqual( + response, + [ + { + "name": "Pythonistas", + "genres": [ + {"name": "Rock"}, + {"name": "Folk"}, + ], + }, + {"name": "Rustaceans", "genres": [{"name": "Folk"}]}, + { + "name": "C-Sharps", + "genres": [ + {"name": "Rock"}, + {"name": "Classical"}, + ], + }, + ], + ) + + def test_add_m2m(self): + """ + Make sure we can add items to the joining table. 
+ """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m(Genre(name="Punk Rock"), m2m=Band.genres).run_sync() + + self.assertTrue( + Genre.exists().where(Genre.name == "Punk Rock").run_sync() + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Punk Rock", + ) + .run_sync(), + 1, + ) + + def test_extra_columns_str(self): + """ + Make sure the ``extra_column_values`` parameter for ``add_m2m`` works + correctly when the dictionary keys are strings. + """ + + reason = "Their second album was very punk rock." + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m( + Genre(name="Punk Rock"), + m2m=Band.genres, + extra_column_values={ + "reason": "Their second album was very punk rock." + }, + ).run_sync() + + Genreto_band = ( + GenreToBand.objects() + .get( + (GenreToBand.band.name == "Pythonistas") + & (GenreToBand.genre.name == "Punk Rock") + ) + .run_sync() + ) + assert Genreto_band is not None + + self.assertEqual(Genreto_band.reason, reason) + + def test_extra_columns_class(self): + """ + Make sure the ``extra_column_values`` parameter for ``add_m2m`` works + correctly when the dictionary keys are ``Column`` classes. + """ + + reason = "Their second album was very punk rock." + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m( + Genre(name="Punk Rock"), + m2m=Band.genres, + extra_column_values={ + GenreToBand.reason: "Their second album was very punk rock." + }, + ).run_sync() + + Genreto_band = ( + GenreToBand.objects() + .get( + (GenreToBand.band.name == "Pythonistas") + & (GenreToBand.genre.name == "Punk Rock") + ) + .run_sync() + ) + assert Genreto_band is not None + + self.assertEqual(Genreto_band.reason, reason) + + def test_add_m2m_existing(self): + """ + Make sure we can add an existing element to the joining table. + """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genre = Genre.objects().get(Genre.name == "Classical").run_sync() + assert genre is not None + + band.add_m2m(genre, m2m=Band.genres).run_sync() + + # We shouldn't have created a duplicate genre in the database. + self.assertEqual( + Genre.count().where(Genre.name == "Classical").run_sync(), 1 + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Classical", + ) + .run_sync(), + 1, + ) + + def test_get_m2m(self): + """ + Make sure we can get related items via the joining table. + """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genres = band.get_m2m(Band.genres).run_sync() + + self.assertTrue(all(isinstance(i, Table) for i in genres)) + + self.assertEqual([i.name for i in genres], ["Rock", "Folk"]) + + def test_get_m2m_no_rows(self): + """ + If there are no matching objects, then an empty list should be + returned. + + https://github.com/piccolo-orm/piccolo/issues/1090 + + """ + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + Genre.delete(force=True).run_sync() + + genres = band.get_m2m(Band.genres).run_sync() + self.assertEqual(genres, []) + + def test_remove_m2m(self): + """ + Make sure we can remove related items via the joining table. 
+ """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genre = Genre.objects().get(Genre.name == "Rock").run_sync() + assert genre is not None + + band.remove_m2m(genre, m2m=Band.genres).run_sync() + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Rock", + ) + .run_sync(), + 0, + ) + + # Make sure the others weren't removed: + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Folk", + ) + .run_sync(), + 1, + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "C-Sharps", + GenreToBand.genre.name == "Rock", + ) + .run_sync(), + 1, + ) diff --git a/tests/columns/m2m/test_m2m.py b/tests/columns/m2m/test_m2m.py new file mode 100644 index 000000000..c2b9d1f42 --- /dev/null +++ b/tests/columns/m2m/test_m2m.py @@ -0,0 +1,422 @@ +import asyncio +import datetime +import decimal +import uuid +from unittest import TestCase + +from piccolo.utils.encoding import JSONDict +from tests.base import engines_skip + +try: + from asyncpg.pgproto.pgproto import UUID as asyncpgUUID +except ImportError: + # In case someone is running the tests for SQLite and doesn't have asyncpg + # installed. + from uuid import UUID as asyncpgUUID + +from piccolo.columns.column_types import ( + JSON, + JSONB, + UUID, + Array, + BigInt, + Boolean, + Bytea, + Date, + DoublePrecision, + ForeignKey, + Integer, + Interval, + LazyTableReference, + Numeric, + Real, + SmallInt, + Text, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.columns.m2m import M2M +from piccolo.engine.finder import engine_finder +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync + +from .base import M2MBase + +engine = engine_finder() + + +class TestM2M(M2MBase, TestCase): + def setUp(self): + return self._setUp(schema=None) + + +############################################################################### + +# A schema using custom primary keys + + +class Customer(Table): + uuid = UUID(primary_key=True) + name = Varchar() + concerts = M2M( + LazyTableReference("CustomerToConcert", module_path=__name__) + ) + + +class Concert(Table): + uuid = UUID(primary_key=True) + name = Varchar() + customers = M2M( + LazyTableReference("CustomerToConcert", module_path=__name__) + ) + + +class CustomerToConcert(Table): + customer = ForeignKey(Customer) + concert = ForeignKey(Concert) + + +CUSTOM_PK_SCHEMA = [Customer, Concert, CustomerToConcert] + + +class TestM2MCustomPrimaryKey(TestCase): + """ + Make sure the M2M functionality works correctly when the tables have custom + primary key columns. 
+ """ + + def setUp(self): + create_db_tables_sync(*CUSTOM_PK_SCHEMA, if_not_exists=True) + + bob = Customer.objects().create(name="Bob").run_sync() + sally = Customer.objects().create(name="Sally").run_sync() + fred = Customer.objects().create(name="Fred").run_sync() + + rockfest = Concert.objects().create(name="Rockfest").run_sync() + folkfest = Concert.objects().create(name="Folkfest").run_sync() + classicfest = Concert.objects().create(name="Classicfest").run_sync() + + CustomerToConcert.insert( + CustomerToConcert(customer=bob, concert=rockfest), + CustomerToConcert(customer=bob, concert=classicfest), + CustomerToConcert(customer=sally, concert=rockfest), + CustomerToConcert(customer=sally, concert=folkfest), + CustomerToConcert(customer=fred, concert=classicfest), + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*CUSTOM_PK_SCHEMA) + + @engines_skip("cockroach") + def test_select(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + response = Customer.select( + Customer.name, Customer.concerts(Concert.name, as_list=True) + ).run_sync() + + self.assertListEqual( + response, + [ + {"name": "Bob", "concerts": ["Rockfest", "Classicfest"]}, + {"name": "Sally", "concerts": ["Rockfest", "Folkfest"]}, + {"name": "Fred", "concerts": ["Classicfest"]}, + ], + ) + + # Now try it in reverse. + response = Concert.select( + Concert.name, Concert.customers(Customer.name, as_list=True) + ).run_sync() + + self.assertListEqual( + response, + [ + {"name": "Rockfest", "customers": ["Bob", "Sally"]}, + {"name": "Folkfest", "customers": ["Sally"]}, + {"name": "Classicfest", "customers": ["Bob", "Fred"]}, + ], + ) + + def test_add_m2m(self): + """ + Make sure we can add items to the joining table. + """ + customer = Customer.objects().get(Customer.name == "Bob").run_sync() + assert customer is not None + customer.add_m2m( + Concert(name="Jazzfest"), m2m=Customer.concerts + ).run_sync() + + self.assertTrue( + Concert.exists().where(Concert.name == "Jazzfest").run_sync() + ) + + self.assertEqual( + CustomerToConcert.count() + .where( + CustomerToConcert.customer.name == "Bob", + CustomerToConcert.concert.name == "Jazzfest", + ) + .run_sync(), + 1, + ) + + def test_add_m2m_within_transaction(self): + """ + Make sure we can add items to the joining table, when within an + existing transaction. + + https://github.com/piccolo-orm/piccolo/issues/674 + + """ + engine = Customer._meta.db + + async def add_m2m_in_transaction(): + async with engine.transaction(): + customer = await Customer.objects().get(Customer.name == "Bob") + assert customer is not None + await customer.add_m2m( + Concert(name="Jazzfest"), m2m=Customer.concerts + ) + + asyncio.run(add_m2m_in_transaction()) + + self.assertTrue( + Concert.exists().where(Concert.name == "Jazzfest").run_sync() + ) + + self.assertEqual( + CustomerToConcert.count() + .where( + CustomerToConcert.customer.name == "Bob", + CustomerToConcert.concert.name == "Jazzfest", + ) + .run_sync(), + 1, + ) + + def test_get_m2m(self): + """ + Make sure we can get related items via the joining table. 
+ """ + customer = Customer.objects().get(Customer.name == "Bob").run_sync() + assert customer is not None + + concerts = customer.get_m2m(Customer.concerts).run_sync() + + self.assertTrue(all(isinstance(i, Table) for i in concerts)) + + self.assertCountEqual( + [i.name for i in concerts], ["Rockfest", "Classicfest"] + ) + + +############################################################################### + +# Test a very complex schema + + +class SmallTable(Table): + varchar_col = Varchar() + mega_rows = M2M(LazyTableReference("SmallToMega", module_path=__name__)) + + +if engine.engine_type != "cockroach": # type: ignore + + class MegaTable(Table): # type: ignore + """ + A table containing all of the column types and different column kwargs + """ + + array_col = Array(Varchar()) + bigint_col = BigInt() + boolean_col = Boolean() + bytea_col = Bytea() + date_col = Date() + double_precision_col = DoublePrecision() + integer_col = Integer() + interval_col = Interval() + json_col = JSON() + jsonb_col = JSONB() + numeric_col = Numeric(digits=(5, 2)) + real_col = Real() + smallint_col = SmallInt() + text_col = Text() + timestamp_col = Timestamp() + timestamptz_col = Timestamptz() + uuid_col = UUID() + varchar_col = Varchar() + +else: + + class MegaTable(Table): # type: ignore + """ + Special version for Cockroach. + A table containing all of the column types and different column kwargs + """ + + array_col = Array(Varchar()) + bigint_col = BigInt() + boolean_col = Boolean() + bytea_col = Bytea() + date_col = Date() + double_precision_col = DoublePrecision() + integer_col = BigInt() + interval_col = Interval() + json_col = JSONB() + jsonb_col = JSONB() + numeric_col = Numeric(digits=(5, 2)) + real_col = Real() + smallint_col = SmallInt() + text_col = Text() + timestamp_col = Timestamp() + timestamptz_col = Timestamptz() + uuid_col = UUID() + varchar_col = Varchar() + + +class SmallToMega(Table): + small = ForeignKey(MegaTable) + mega = ForeignKey(SmallTable) + + +COMPLEX_SCHEMA = [MegaTable, SmallTable, SmallToMega] + + +class TestM2MComplexSchema(TestCase): + """ + By using a very complex schema containing every column type, we can catch + more edge cases. + """ + + def setUp(self): + create_db_tables_sync(*COMPLEX_SCHEMA, if_not_exists=True) + + small_table = SmallTable(varchar_col="Test") + small_table.save().run_sync() + + mega_table = MegaTable( + array_col=["bob", "sally"], + bigint_col=1, + boolean_col=True, + bytea_col="hello".encode("utf8"), + date_col=datetime.date(year=2021, month=1, day=1), + double_precision_col=1.344, + integer_col=1, + interval_col=datetime.timedelta(seconds=10), + json_col={"a": 1}, + jsonb_col={"a": 1}, + numeric_col=decimal.Decimal("1.1"), + real_col=1.1, + smallint_col=1, + text_col="hello", + timestamp_col=datetime.datetime(year=2021, month=1, day=1), + timestamptz_col=datetime.datetime( + year=2021, month=1, day=1, tzinfo=datetime.timezone.utc + ), + uuid_col=uuid.UUID("12783854-c012-4c15-8183-8eecb46f2c4e"), + varchar_col="hello", + ) + mega_table.save().run_sync() + + SmallToMega(small=small_table, mega=mega_table).save().run_sync() + + self.mega_table = mega_table + + def tearDown(self): + drop_db_tables_sync(*COMPLEX_SCHEMA) + + @engines_skip("cockroach") + def test_select_all(self): + """ + Fetch all of the columns from the related table to make sure they're + returned correctly. 
+ """ + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + response = SmallTable.select( + SmallTable.varchar_col, SmallTable.mega_rows(load_json=True) + ).run_sync() + + self.assertEqual(len(response), 1) + mega_rows = response[0]["mega_rows"] + + self.assertEqual(len(mega_rows), 1) + mega_row = mega_rows[0] + + for key, value in mega_row.items(): + # Make sure that every value in the response matches what we saved. + self.assertAlmostEqual( + getattr(self.mega_table, key), + value, + msg=f"{key} doesn't match", + ) + + @engines_skip("cockroach") + def test_select_single(self): + """ + Make sure each column can be selected one at a time. + """ + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + for column in MegaTable._meta.columns: + response = SmallTable.select( + SmallTable.varchar_col, + SmallTable.mega_rows(column, load_json=True), + ).run_sync() + + data = response[0]["mega_rows"][0] + column_name = column._meta.name + + original_value = getattr(self.mega_table, column_name) + returned_value = data[column_name] + + if isinstance(column, UUID): + self.assertIn(type(returned_value), (uuid.UUID, asyncpgUUID)) + elif isinstance(column, (JSON, JSONB)): + self.assertEqual(type(returned_value), JSONDict) + self.assertEqual(original_value, returned_value) + else: + self.assertEqual( + type(original_value), + type(returned_value), + msg=f"{column_name} type isn't correct", + ) + + self.assertAlmostEqual( + original_value, + returned_value, + msg=f"{column_name} doesn't match", + ) + + # Test it as a list too + response = SmallTable.select( + SmallTable.varchar_col, + SmallTable.mega_rows(column, as_list=True, load_json=True), + ).run_sync() + + original_value = getattr(self.mega_table, column_name) + returned_value = response[0]["mega_rows"][0] + + if isinstance(column, UUID): + self.assertIn(type(returned_value), (uuid.UUID, asyncpgUUID)) + self.assertEqual(str(original_value), str(returned_value)) + elif isinstance(column, (JSON, JSONB)): + self.assertEqual(type(returned_value), JSONDict) + self.assertEqual(original_value, returned_value) + else: + self.assertEqual( + type(original_value), + type(returned_value), + msg=f"{column_name} type isn't correct", + ) + + self.assertAlmostEqual( + original_value, + returned_value, + msg=f"{column_name} doesn't match", + ) diff --git a/tests/columns/m2m/test_m2m_schema.py b/tests/columns/m2m/test_m2m_schema.py new file mode 100644 index 000000000..01ed90681 --- /dev/null +++ b/tests/columns/m2m/test_m2m_schema.py @@ -0,0 +1,16 @@ +from unittest import TestCase + +from tests.base import engines_skip + +from .base import M2MBase + + +@engines_skip("sqlite") +class TestM2MWithSchema(M2MBase, TestCase): + """ + Make sure that when the tables exist in a non-public schema, that M2M still + works. 
+ """ + + def setUp(self): + return self._setUp(schema="schema_1") diff --git a/tests/columns/test_array.py b/tests/columns/test_array.py index 7e0141925..d347d0fe5 100644 --- a/tests/columns/test_array.py +++ b/tests/columns/test_array.py @@ -1,88 +1,546 @@ +import datetime from unittest import TestCase -from piccolo.columns.column_types import Array, Integer +import pytest + +from piccolo.columns.column_types import ( + Array, + BigInt, + Date, + Integer, + Time, + Timestamp, + Timestamptz, +) +from piccolo.querystring import QueryString from piccolo.table import Table -from tests.base import postgres_only +from piccolo.testing.test_case import TableTest +from tests.base import engines_only, engines_skip, sqlite_only class MyTable(Table): value = Array(base_column=Integer()) -class TestArrayPostgres(TestCase): +class TestArrayDefault(TestCase): + def test_array_default(self): + """ + We use ``ListProxy`` instead of ``list`` as a default, because of + issues with Sphinx's autodoc. Make sure it's correctly converted to a + plain ``list`` in ``Array.__init__``. + """ + column = Array(base_column=Integer()) + self.assertTrue(column.default is list) + + +class TestArray(TableTest): """ - Make sure an Array column can be created. + Make sure an Array column can be created, and works correctly. """ - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() + tables = [MyTable] + @pytest.mark.cockroach_array_slow def test_storage(self): """ Make sure data can be stored and retrieved. - """ + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 MyTable(value=[1, 2, 3]).save().run_sync() row = MyTable.objects().first().run_sync() + assert row is not None self.assertEqual(row.value, [1, 2, 3]) - @postgres_only + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow def test_index(self): """ Indexes should allow individual array elements to be queried. - """ + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 MyTable(value=[1, 2, 3]).save().run_sync() self.assertEqual( MyTable.select(MyTable.value[0]).first().run_sync(), {"value": 1} ) - @postgres_only + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow def test_all(self): """ Make sure rows can be retrieved where all items in an array match a given value. - """ + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 MyTable(value=[1, 1, 1]).save().run_sync() + # We have to explicitly specify the type, so CockroachDB works. self.assertEqual( MyTable.select(MyTable.value) - .where(MyTable.value.all(1)) + .where(MyTable.value.all(QueryString("{}::INTEGER", 1))) .first() .run_sync(), {"value": [1, 1, 1]}, ) + # We have to explicitly specify the type, so CockroachDB works. 
self.assertEqual( MyTable.select(MyTable.value) - .where(MyTable.value.all(0)) + .where(MyTable.value.all(QueryString("{}::INTEGER", 0))) .first() .run_sync(), None, ) + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow def test_any(self): """ Make sure rows can be retrieved where any items in an array match a given value. - """ + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[1, 2, 3]).save().run_sync() + # We have to explicitly specify the type, so CockroachDB works. self.assertEqual( MyTable.select(MyTable.value) - .where(MyTable.value.any(1)) + .where(MyTable.value.any(QueryString("{}::INTEGER", 1))) .first() .run_sync(), {"value": [1, 2, 3]}, ) + # We have to explicitly specify the type, so CockroachDB works. self.assertEqual( MyTable.select(MyTable.value) - .where(MyTable.value.any(0)) + .where(MyTable.value.any(QueryString("{}::INTEGER", 0))) .first() .run_sync(), None, ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_not_any(self): + """ + Make sure rows can be retrieved where the array doesn't contain a + certain value. + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + + MyTable(value=[1, 2, 3]).save().run_sync() + MyTable(value=[4, 5, 6]).save().run_sync() + + # We have to explicitly specify the type, so CockroachDB works. + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.not_any(QueryString("{}::INTEGER", 4))) + .run_sync(), + [{"value": [1, 2, 3]}], + ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_cat(self): + """ + Make sure values can be appended to an array and that we can concatenate two arrays. + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[5]).save().run_sync() + + MyTable.update( + {MyTable.value: MyTable.value.cat([6])}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [5, 6]}], + ) + + # Try plus symbol - add array to the end + + MyTable.update( + {MyTable.value: MyTable.value + [7]}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [5, 6, 7]}], + ) + + # Add array to the start + + MyTable.update( + {MyTable.value: [4] + MyTable.value}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [4, 5, 6, 7]}], + ) + + # Add array to the start and end + MyTable.update( + {MyTable.value: [3] + MyTable.value + [8]}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [3, 4, 5, 6, 7, 8]}], + ) + + @sqlite_only + def test_cat_sqlite(self): + """ + If using SQLite then an exception should be raised currently. 
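+
+        (SQLite has no native array type; Piccolo stores arrays as
+        serialised JSON strings there, so concatenation can't be pushed down
+        to the database.)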
+ """ + with self.assertRaises(ValueError) as manager: + MyTable.value.cat([2]) + + self.assertEqual( + str(manager.exception), + "Only Postgres and Cockroach support array concatenation.", + ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_prepend(self): + """ + Make sure values can be added to the beginning of the array. + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[1, 1, 1]).save().run_sync() + + MyTable.update( + {MyTable.value: MyTable.value.prepend(3)}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [3, 1, 1, 1]}], + ) + + @sqlite_only + def test_prepend_sqlite(self): + """ + If using SQLite then an exception should be raised currently. + """ + with self.assertRaises(ValueError) as manager: + MyTable.value.prepend(2) + + self.assertEqual( + str(manager.exception), + "Only Postgres and Cockroach support array prepending.", + ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_append(self): + """ + Make sure values can be appended to an array. + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[1, 1, 1]).save().run_sync() + + MyTable.update( + {MyTable.value: MyTable.value.append(3)}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [1, 1, 1, 3]}], + ) + + @sqlite_only + def test_append_sqlite(self): + """ + If using SQLite then an exception should be raised currently. + """ + with self.assertRaises(ValueError) as manager: + MyTable.value.append(2) + + self.assertEqual( + str(manager.exception), + "Only Postgres and Cockroach support array appending.", + ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_replace(self): + """ + Make sure values can be swapped in the array. + + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[1, 1, 1]).save().run_sync() + + MyTable.update( + {MyTable.value: MyTable.value.replace(1, 2)}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [2, 2, 2]}], + ) + + @sqlite_only + def test_replace_sqlite(self): + """ + If using SQLite then an exception should be raised currently. + """ + with self.assertRaises(ValueError) as manager: + MyTable.value.replace(1, 2) + + self.assertEqual( + str(manager.exception), + "Only Postgres and Cockroach support array substitution.", + ) + + @engines_skip("sqlite") + @pytest.mark.cockroach_array_slow + def test_remove(self): + """ + Make sure values can be removed from an array. 
+ + In CockroachDB <= v22.2.0 we had this error: + + * https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + In newer CockroachDB versions, it runs but is very slow: + + * https://github.com/piccolo-orm/piccolo/issues/1005 + + """ # noqa: E501 + MyTable(value=[1, 2, 3]).save().run_sync() + + MyTable.update( + {MyTable.value: MyTable.value.remove(2)}, force=True + ).run_sync() + + self.assertEqual( + MyTable.select(MyTable.value).run_sync(), + [{"value": [1, 3]}], + ) + + @sqlite_only + def test_remove_sqlite(self): + """ + If using SQLite then an exception should be raised currently. + """ + with self.assertRaises(ValueError) as manager: + MyTable.value.remove(2) + + self.assertEqual( + str(manager.exception), + "Only Postgres and Cockroach support array removing.", + ) + + +############################################################################### +# Date and time arrays + + +class DateTimeArrayTable(Table): + date = Array(Date()) + time = Array(Time()) + timestamp = Array(Timestamp()) + timestamptz = Array(Timestamptz()) + date_nullable = Array(Date(), null=True) + time_nullable = Array(Time(), null=True) + timestamp_nullable = Array(Timestamp(), null=True) + timestamptz_nullable = Array(Timestamptz(), null=True) + + +class TestDateTimeArray(TestCase): + """ + Make sure that data can be stored and retrieved when using arrays of + date / time / timestamp. + + We have to serialise / deserialise it in a special way in SQLite, hence + the tests. + + """ + + def setUp(self): + DateTimeArrayTable.create_table().run_sync() + + def tearDown(self): + DateTimeArrayTable.alter().drop_table().run_sync() + + @engines_only("postgres", "sqlite") + def test_storage(self): + test_date = datetime.date(year=2024, month=1, day=1) + test_time = datetime.time(hour=12, minute=0) + test_timestamp = datetime.datetime( + year=2024, month=1, day=1, hour=12, minute=0 + ) + test_timestamptz = datetime.datetime( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + tzinfo=datetime.timezone.utc, + ) + + DateTimeArrayTable( + { + DateTimeArrayTable.date: [test_date], + DateTimeArrayTable.time: [test_time], + DateTimeArrayTable.timestamp: [test_timestamp], + DateTimeArrayTable.timestamptz: [test_timestamptz], + DateTimeArrayTable.date_nullable: None, + DateTimeArrayTable.time_nullable: None, + DateTimeArrayTable.timestamp_nullable: None, + DateTimeArrayTable.timestamptz_nullable: None, + } + ).save().run_sync() + + row = DateTimeArrayTable.objects().first().run_sync() + assert row is not None + + self.assertListEqual(row.date, [test_date]) + self.assertListEqual(row.time, [test_time]) + self.assertListEqual(row.timestamp, [test_timestamp]) + self.assertListEqual(row.timestamptz, [test_timestamptz]) + + self.assertIsNone(row.date_nullable) + self.assertIsNone(row.time_nullable) + self.assertIsNone(row.timestamp_nullable) + self.assertIsNone(row.timestamptz_nullable) + + +############################################################################### +# Nested arrays + + +class NestedArrayTable(Table): + value = Array(base_column=Array(base_column=BigInt())) + + +class TestNestedArray(TestCase): + """ + Make sure that tables with nested arrays can be created, and work + correctly. 
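+
+    (``Array(Array(BigInt()))`` presumably maps to a multidimensional
+    ``BIGINT[][]`` column in Postgres.)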
+ + Related to this bug, with nested array columns containing ``BigInt``: + + https://github.com/piccolo-orm/piccolo/issues/936 + + """ + + def setUp(self): + NestedArrayTable.create_table().run_sync() + + def tearDown(self): + NestedArrayTable.alter().drop_table().run_sync() + + @engines_only("postgres", "sqlite") + def test_storage(self): + """ + Make sure data can be stored and retrieved. + + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + + """ # noqa: E501 + NestedArrayTable(value=[[1, 2, 3], [4, 5, 6]]).save().run_sync() + + row = NestedArrayTable.objects().first().run_sync() + assert row is not None + self.assertEqual(row.value, [[1, 2, 3], [4, 5, 6]]) + + +class TestGetDimensions(TestCase): + def test_get_dimensions(self): + """ + Make sure that `_get_dimensions` returns the correct value. + """ + self.assertEqual(Array(Integer())._get_dimensions(), 1) + self.assertEqual(Array(Array(Integer()))._get_dimensions(), 2) + self.assertEqual(Array(Array(Array(Integer())))._get_dimensions(), 3) + + +class TestGetInnerValueType(TestCase): + def test_get_inner_value_type(self): + """ + Make sure that `_get_inner_value_type` returns the correct base type. + """ + self.assertEqual(Array(Integer())._get_inner_value_type(), int) + self.assertEqual(Array(Array(Integer()))._get_inner_value_type(), int) + self.assertEqual( + Array(Array(Array(Integer())))._get_inner_value_type(), int + ) diff --git a/tests/columns/test_base.py b/tests/columns/test_base.py index 714b0dd0d..2ef4c2a8c 100644 --- a/tests/columns/test_base.py +++ b/tests/columns/test_base.py @@ -4,31 +4,13 @@ from piccolo.columns.choices import Choice from piccolo.columns.column_types import Integer, Varchar from piccolo.table import Table +from tests.example_apps.music.tables import Band, Manager class MyTable(Table): name = Varchar() -class TestColumn(TestCase): - def test_like_raises(self): - """ - Make sure an invalid 'like' argument raises an exception. Should - contain a % symbol. - """ - column = MyTable.name - with self.assertRaises(ValueError): - column.like("guido") - - with self.assertRaises(ValueError): - column.ilike("guido") - - # Make sure valid args don't raise an exception. - for arg in ["%guido", "guido%", "%guido%"]: - column.like("%foo") - column.ilike("foo%") - - class TestCopy(TestCase): def test_copy(self): """ @@ -50,7 +32,17 @@ def test_help_text(self): """ help_text = "This is some important help text for users." column = Varchar(help_text=help_text) - self.assertTrue(column._meta.help_text == help_text) + self.assertEqual(column._meta.help_text, help_text) + + +class TestSecretParameter(TestCase): + def test_secret_parameter(self): + """ + Test adding secret parameter to a column. + """ + secret = False + column = Varchar(secret=secret) + self.assertEqual(column._meta.secret, secret) class TestChoices(TestCase): @@ -121,3 +113,35 @@ class Title(Enum): "mrs": {"display_name": "Mrs.", "value": 2}, }, ) + + +class TestEquals(TestCase): + def test_non_column(self): + """ + Make sure non-column values don't match. + """ + for value in (1, "abc", None): + self.assertFalse(Manager.name._equals(value)) # type: ignore + + def test_equals(self): + """ + Test basic usage. + """ + self.assertTrue(Manager.name._equals(Manager.name)) + + def test_same_name(self): + """ + Make sure that columns with the same name, but on different tables, + don't match. 
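+
+        (``Manager.name`` and ``Band.name`` are both ``Varchar`` columns
+        called ``name``; ``_equals`` should still distinguish them by their
+        parent table.)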
+ """ + self.assertFalse(Manager.name._equals(Band.name)) + + def test_including_joins(self): + """ + Make sure `including_joins` arg works correctly. + """ + self.assertTrue(Band.manager.name._equals(Manager.name)) + + self.assertFalse( + Band.manager.name._equals(Manager.name, including_joins=True) + ) diff --git a/tests/columns/test_bigint.py b/tests/columns/test_bigint.py index 41dc71a2a..9cb1b8ae4 100644 --- a/tests/columns/test_bigint.py +++ b/tests/columns/test_bigint.py @@ -1,31 +1,26 @@ import os -from unittest import TestCase from piccolo.columns.column_types import BigInt from piccolo.table import Table - -from ..base import postgres_only +from piccolo.testing.test_case import TableTest +from tests.base import engines_only class MyTable(Table): value = BigInt() -@postgres_only -class TestBigIntPostgres(TestCase): +@engines_only("postgres", "cockroach") +class TestBigIntPostgres(TableTest): """ Make sure a BigInt column in Postgres can store a large number. """ - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() + tables = [MyTable] def _test_length(self): # Can store 8 bytes, but split between positive and negative values. - max_value = int(2 ** 64 / 2) - 1 + max_value = int(2**64 / 2) - 1 min_value = max_value * -1 print("Testing max value") diff --git a/tests/columns/test_boolean.py b/tests/columns/test_boolean.py index eea3df8d0..2cba15767 100644 --- a/tests/columns/test_boolean.py +++ b/tests/columns/test_boolean.py @@ -1,37 +1,39 @@ -from unittest import TestCase +from typing import Any from piccolo.columns.column_types import Boolean from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): - boolean = Boolean(boolean=False, null=True) + boolean = Boolean(default=False, null=True) -class TestBoolean(TestCase): - def setUp(self): - MyTable.create_table().run_sync() +class TestBoolean(TableTest): + tables = [MyTable] - def tearDown(self): - MyTable.alter().drop_table().run_sync() - - def test_return_type(self): + def test_return_type(self) -> None: for value in (True, False, None, ...): - kwargs = {} if value is ... else {"boolean": value} + kwargs: dict[str, Any] = {} if value is ... else {"boolean": value} expected = MyTable.boolean.default if value is ... 
else value row = MyTable(**kwargs) row.save().run_sync() self.assertEqual(row.boolean, expected) - self.assertEqual( + row_from_db = ( MyTable.select(MyTable.boolean) .where( MyTable._meta.primary_key == getattr(row, MyTable._meta.primary_key._meta.name) ) .first() - .run_sync()["boolean"], + .run_sync() + ) + assert row_from_db is not None + + self.assertEqual( + row_from_db["boolean"], expected, ) diff --git a/tests/columns/test_bytea.py b/tests/columns/test_bytea.py index 666a5a8d6..8114e9325 100644 --- a/tests/columns/test_bytea.py +++ b/tests/columns/test_bytea.py @@ -1,7 +1,6 @@ -from unittest import TestCase - from piccolo.columns.column_types import Bytea from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -19,12 +18,8 @@ class MyTableDefault(Table): token_none = Bytea(default=None, null=True) -class TestBytea(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestBytea(TableTest): + tables = [MyTable] def test_bytea(self): """ @@ -40,12 +35,8 @@ def test_bytea(self): ) -class TestByteaDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestByteaDefault(TableTest): + tables = [MyTableDefault] def test_json_default(self): row = MyTableDefault() @@ -59,4 +50,4 @@ def test_json_default(self): def test_invalid_default(self): with self.assertRaises(ValueError): for value in ("a", 1, ("x", "y", "z")): - Bytea(default=value) + Bytea(default=value) # type: ignore diff --git a/tests/columns/test_choices.py b/tests/columns/test_choices.py index 4be9f9921..d3e1822e5 100644 --- a/tests/columns/test_choices.py +++ b/tests/columns/test_choices.py @@ -1,8 +1,15 @@ -from tests.base import DBTestCase -from tests.example_app.tables import Shirt +import enum +from piccolo.columns.column_types import Array, Varchar +from piccolo.table import Table +from piccolo.testing.test_case import TableTest +from tests.base import engines_only +from tests.example_apps.music.tables import Shirt + + +class TestChoices(TableTest): + tables = [Shirt] -class TestChoices(DBTestCase): def _insert_shirts(self): Shirt.insert( Shirt(size=Shirt.Size.small), @@ -23,6 +30,7 @@ def test_default(self): """ Shirt().save().run_sync() shirt = Shirt.objects().first().run_sync() + assert shirt is not None self.assertEqual(shirt.size, "l") def test_update(self): @@ -63,3 +71,71 @@ def test_objects_where(self): ) self.assertEqual(len(shirts), 1) self.assertEqual(shirts[0].size, "s") + + +class Ticket(Table): + class Extras(str, enum.Enum): + drink = "drink" + snack = "snack" + program = "program" + + extras = Array(Varchar(), choices=Extras) + + +@engines_only("postgres", "sqlite") +class TestArrayChoices(TableTest): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg + """ # noqa: E501 + + tables = [Ticket] + + def test_string(self): + """ + Make sure strings can be passed in as choices. + """ + ticket = Ticket(extras=["drink", "snack", "program"]) + ticket.save().run_sync() + + self.assertListEqual( + Ticket.select(Ticket.extras).run_sync(), + [{"extras": ["drink", "snack", "program"]}], + ) + + def test_enum(self): + """ + Make sure enums can be passed in as choices. 
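+
+        Each ``Ticket.Extras`` member should be stored as its string
+        ``value`` (e.g. ``"drink"``), which is what the select below returns.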
+ """ + ticket = Ticket( + extras=[ + Ticket.Extras.drink, + Ticket.Extras.snack, + Ticket.Extras.program, + ] + ) + ticket.save().run_sync() + + self.assertListEqual( + Ticket.select(Ticket.extras).run_sync(), + [{"extras": ["drink", "snack", "program"]}], + ) + + def test_invalid_choices(self): + """ + Make sure an invalid choices Enum is rejected. + """ + with self.assertRaises(ValueError) as manager: + + class Ticket(Table): + # This will be rejected, because the values are ints, and the + # Array's base_column is Varchar. + class Extras(int, enum.Enum): + drink = 1 + snack = 2 + program = 3 + + extras = Array(Varchar(), choices=Extras) + + self.assertEqual( + manager.exception.__str__(), "drink doesn't have the correct type" + ) diff --git a/tests/columns/test_combination.py b/tests/columns/test_combination.py index c580e598e..4f36bf99f 100644 --- a/tests/columns/test_combination.py +++ b/tests/columns/test_combination.py @@ -1,36 +1,84 @@ import unittest -from ..example_app.tables import Band +from tests.example_apps.music.tables import Band, Concert class TestWhere(unittest.TestCase): def test_equals(self): _where = Band.name == "Pythonistas" sql = _where.__str__() - self.assertEqual(sql, "band.name = 'Pythonistas'") + self.assertEqual(sql, '"band"."name" = \'Pythonistas\'') def test_not_equal(self): _where = Band.name != "CSharps" sql = _where.__str__() - self.assertEqual(sql, "band.name != 'CSharps'") + self.assertEqual(sql, '"band"."name" != \'CSharps\'') def test_like(self): _where = Band.name.like("Python%") sql = _where.__str__() - self.assertEqual(sql, "band.name LIKE 'Python%'") + self.assertEqual(sql, '"band"."name" LIKE \'Python%\'') def test_is_in(self): _where = Band.name.is_in(["Pythonistas", "Rustaceans"]) sql = _where.__str__() - self.assertEqual(sql, "band.name IN ('Pythonistas', 'Rustaceans')") + self.assertEqual( + sql, "\"band\".\"name\" IN ('Pythonistas', 'Rustaceans')" + ) with self.assertRaises(ValueError): Band.name.is_in([]) + def test_is_in_subquery(self): + _where = Band.id.is_in( + Concert.select(Concert.band_1).where(Concert.band_1 == 1) + ) + sql = _where.__str__() + self.assertEqual( + sql, + '"band"."id" IN (SELECT ALL "concert"."band_1" AS "band_1" FROM "concert" WHERE "concert"."band_1" = 1)', # noqa: E501 + ) + + # a sub select must only return a single column + with self.assertRaises(ValueError): + Band.id.is_in(Concert.select().where(Concert.band_1 == 1)) + def test_not_in(self): _where = Band.name.not_in(["CSharps"]) sql = _where.__str__() - self.assertEqual(sql, "band.name NOT IN ('CSharps')") + self.assertEqual(sql, '"band"."name" NOT IN (\'CSharps\')') with self.assertRaises(ValueError): Band.name.not_in([]) + + def test_not_in_subquery(self): + _where = Band.id.not_in( + Concert.select(Concert.band_1).where(Concert.band_1 == 1) + ) + sql = _where.__str__() + self.assertEqual( + sql, + '"band"."id" NOT IN (SELECT ALL "concert"."band_1" AS "band_1" FROM "concert" WHERE "concert"."band_1" = 1)', # noqa: E501 + ) + + # a sub select must only return a single column + with self.assertRaises(ValueError): + Band.id.not_in(Concert.select().where(Concert.band_1 == 1)) + + +class TestAnd(unittest.TestCase): + def test_get_column_values(self): + """ + Make sure that we can extract the column values from an ``And``. 
+ + There was a bug with ``None`` values not working: + + https://github.com/piccolo-orm/piccolo/issues/715 + + """ + And_ = (Band.manager.is_null()) & (Band.name == "Pythonistas") + column_values = And_.get_column_values() + + self.assertDictEqual( + column_values, {Band.name: "Pythonistas", Band.manager: None} + ) diff --git a/tests/columns/test_date.py b/tests/columns/test_date.py index 3169d8bf6..1628c4758 100644 --- a/tests/columns/test_date.py +++ b/tests/columns/test_date.py @@ -1,9 +1,9 @@ import datetime -from unittest import TestCase from piccolo.columns.column_types import Date from piccolo.columns.defaults.date import DateNow from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -14,12 +14,8 @@ class MyTableDefault(Table): created_on = Date(default=DateNow()) -class TestDate(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestDate(TableTest): + tables = [MyTable] def test_timestamp(self): created_on = datetime.datetime.now().date() @@ -27,15 +23,12 @@ def test_timestamp(self): row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) -class TestDateDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestDateDefault(TableTest): + tables = [MyTableDefault] def test_timestamp(self): created_on = datetime.datetime.now().date() @@ -43,4 +36,5 @@ def test_timestamp(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) diff --git a/tests/columns/test_db_column_name.py b/tests/columns/test_db_column_name.py new file mode 100644 index 000000000..1e4195d4e --- /dev/null +++ b/tests/columns/test_db_column_name.py @@ -0,0 +1,313 @@ +from typing import Optional + +from piccolo.columns.column_types import ForeignKey, Integer, Serial, Varchar +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import DBTestCase, engine_is, engines_only, engines_skip + + +class Manager(Table): + id: Serial + name = Varchar() + + +class Band(Table): + id: Serial + name = Varchar(db_column_name="regrettable_column_name") + popularity = Integer() + manager = ForeignKey(Manager, db_column_name="manager_fk") + + +class TestDBColumnName(DBTestCase): + """ + By using the ``db_column_name`` arg, the user can map a ``Column`` to a + database column with a different name. For example: + + .. code-block:: python + + class MyTable(Table): + class_ = Varchar(db_column_name='class') + + """ + + def setUp(self): + create_db_tables_sync(Band, Manager) + + def tearDown(self): + drop_db_tables_sync(Band, Manager) + + def insert_band(self, manager: Optional[Manager] = None) -> Band: + band = Band(name="Pythonistas", popularity=1000, manager=manager) + band.save().run_sync() + return band + + @engines_only("postgres", "cockroach") + def test_column_name_correct(self): + """ + Make sure the column has the correct name in the database. + """ + self.get_postgres_column_definition( + tablename="band", column_name="regrettable_column_name" + ) + + with self.assertRaises(ValueError): + self.get_postgres_column_definition( + tablename="band", column_name="name" + ) + + def test_save(self): + """ + Make sure save queries work correctly. 
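+
+        The Python attribute is still ``band.name``, even though the value
+        is stored in the ``regrettable_column_name`` column in the database.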
+ """ + self.insert_band() + + band_from_db = Band.objects().first().run_sync() + assert band_from_db is not None + self.assertEqual(band_from_db.name, "Pythonistas") + + def test_create(self): + """ + Make sure create queries work correctly. + """ + band = self.insert_band() + self.assertEqual(band.name, "Pythonistas") + + band_from_db = Band.objects().first().run_sync() + assert band_from_db is not None + self.assertEqual(band_from_db.name, "Pythonistas") + + def test_select(self): + """ + Make sure that select queries just return what is stored in the + database. We might add an option in the future which maps the column + name to it's alias, but it's hard to predict what behaviour the user + wants. + """ + self.insert_band() + + # Make sure we can select all columns + bands = Band.select().run_sync() + if engine_is("cockroach"): + self.assertEqual( + bands, + [ + { + "id": bands[0]["id"], + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + else: + self.assertEqual( + bands, + [ + { + "id": 1, + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + + # Make sure we can select a single column + bands = Band.select(Band.name).run_sync() + self.assertEqual( + bands, + [ + { + "regrettable_column_name": "Pythonistas", + } + ], + ) + + # Make sure aliases still work + bands = Band.select(Band.name.as_alias("name")).run_sync() + self.assertEqual( + bands, + [ + { + "name": "Pythonistas", + } + ], + ) + + def test_join(self): + """ + Make sure that foreign keys with a ``db_column_name`` specified still + work for joins. + + https://github.com/piccolo-orm/piccolo/issues/1101 + + """ + manager = Manager.objects().create(name="Guido").run_sync() + band = self.insert_band(manager=manager) + + bands = Band.select().where(Band.manager.name == "Guido").run_sync() + + self.assertListEqual( + bands, + [ + { + "id": band.id, + "manager_fk": manager.id, + "popularity": 1000, + "regrettable_column_name": "Pythonistas", + } + ], + ) + + def test_update(self): + """ + Make sure update queries work correctly. + """ + self.insert_band() + + Band.update({Band.name: "Pythonistas 2"}, force=True).run_sync() + + bands = Band.select().run_sync() + if engine_is("cockroach"): + self.assertEqual( + bands, + [ + { + "id": bands[0]["id"], + "regrettable_column_name": "Pythonistas 2", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + else: + self.assertEqual( + bands, + [ + { + "id": 1, + "regrettable_column_name": "Pythonistas 2", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + + Band.update({"name": "Pythonistas 3"}, force=True).run_sync() + + bands = Band.select().run_sync() + if engine_is("cockroach"): + self.assertEqual( + bands, + [ + { + "id": bands[0]["id"], + "regrettable_column_name": "Pythonistas 3", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + else: + self.assertEqual( + bands, + [ + { + "id": 1, + "regrettable_column_name": "Pythonistas 3", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + + @engines_skip("cockroach") + def test_delete(self): + """ + Make sure delete queries work correctly. 
+ """ + Band.insert( + Band(name="Pythonistas", popularity=1000), + Band(name="Rustaceans", popularity=500), + ).run_sync() + + bands = Band.select().run_sync() + self.assertEqual( + bands, + [ + { + "id": 1, + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + }, + { + "id": 2, + "regrettable_column_name": "Rustaceans", + "popularity": 500, + "manager_fk": None, + }, + ], + ) + + Band.delete().where(Band.name == "Rustaceans").run_sync() + + bands = Band.select().run_sync() + self.assertEqual( + bands, + [ + { + "id": 1, + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + } + ], + ) + + @engines_only("cockroach") + def test_delete_alt(self): + """ + Make sure delete queries work correctly. + """ + result = ( + Band.insert( + Band(name="Pythonistas", popularity=1000), + Band(name="Rustaceans", popularity=500), + ) + .returning(Band.id) + .run_sync() + ) + + bands = Band.select().run_sync() + self.assertEqual( + bands, + [ + { + "id": result[0]["id"], + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + }, + { + "id": result[1]["id"], + "regrettable_column_name": "Rustaceans", + "popularity": 500, + "manager_fk": None, + }, + ], + ) + + Band.delete().where(Band.name == "Rustaceans").run_sync() + + bands = Band.select().run_sync() + self.assertEqual( + bands, + [ + { + "id": result[0]["id"], + "regrettable_column_name": "Pythonistas", + "popularity": 1000, + "manager_fk": None, + } + ], + ) diff --git a/tests/columns/test_defaults.py b/tests/columns/test_defaults.py index b8198d3c6..dbe9ee522 100644 --- a/tests/columns/test_defaults.py +++ b/tests/columns/test_defaults.py @@ -9,6 +9,7 @@ BigInt, Date, DateNow, + DoublePrecision, ForeignKey, Integer, Numeric, @@ -21,6 +22,8 @@ TimestampNow, Varchar, ) +from piccolo.columns.defaults.timestamp import TimestampCustom +from piccolo.columns.defaults.timestamptz import TimestamptzCustom from piccolo.table import Table @@ -34,34 +37,32 @@ def test_int(self): _type(default=0) _type(default=None, null=True) with self.assertRaises(ValueError): - _type(default="hello world") - with self.assertRaises(ValueError): - _type(default=None, null=False) + _type(default="hello world") # type: ignore def test_text(self): for _type in (Text, Varchar): _type(default="") _type(default=None, null=True) with self.assertRaises(ValueError): - _type(default=123) - with self.assertRaises(ValueError): - _type(default=None, null=False) + _type(default=123) # type: ignore def test_real(self): Real(default=0.0) Real(default=None, null=True) with self.assertRaises(ValueError): - Real(default="hello world") + Real(default="hello world") # type: ignore + + def test_double_precision(self): + DoublePrecision(default=0.0) + DoublePrecision(default=None, null=True) with self.assertRaises(ValueError): - Real(default=None, null=False) + DoublePrecision(default="hello world") # type: ignore def test_numeric(self): Numeric(default=decimal.Decimal(1.0)) Numeric(default=None, null=True) with self.assertRaises(ValueError): - Numeric(default="hello world") - with self.assertRaises(ValueError): - Numeric(default=None, null=False) + Numeric(default="hello world") # type: ignore def test_uuid(self): UUID(default=None, null=True) @@ -69,35 +70,27 @@ def test_uuid(self): UUID(default=uuid.uuid4()) with self.assertRaises(ValueError): UUID(default="hello world") - with self.assertRaises(ValueError): - UUID(default=None, null=False) def test_time(self): Time(default=None, null=True) 
Time(default=TimeNow()) Time(default=datetime.datetime.now().time()) with self.assertRaises(ValueError): - Time(default="hello world") - with self.assertRaises(ValueError): - Time(default=None, null=False) + Time(default="hello world") # type: ignore def test_date(self): Date(default=None, null=True) Date(default=DateNow()) Date(default=datetime.datetime.now().date()) with self.assertRaises(ValueError): - Date(default="hello world") - with self.assertRaises(ValueError): - Date(default=None, null=False) + Date(default="hello world") # type: ignore def test_timestamp(self): Timestamp(default=None, null=True) Timestamp(default=TimestampNow()) Timestamp(default=datetime.datetime.now()) with self.assertRaises(ValueError): - Timestamp(default="hello world") - with self.assertRaises(ValueError): - Timestamp(default=None, null=False) + Timestamp(default="hello world") # type: ignore def test_foreignkey(self): class MyTable(Table): @@ -107,5 +100,35 @@ class MyTable(Table): ForeignKey(references=MyTable, default=1) with self.assertRaises(ValueError): ForeignKey(references=MyTable, default="hello world") - with self.assertRaises(ValueError): - ForeignKey(references=MyTable, default=None, null=False) + + +class TestDatetime(TestCase): + + def test_datetime(self): + """ + Make sure we can create a `TimestampCustom` / `TimestamptzCustom` from + a datetime, and then convert it back into the same datetime again. + + https://github.com/piccolo-orm/piccolo/issues/1169 + + """ + datetime_obj = datetime.datetime( + year=2025, + month=1, + day=30, + hour=12, + minute=10, + second=15, + microsecond=100, + ) + + self.assertEqual( + TimestampCustom.from_datetime(datetime_obj).datetime, + datetime_obj, + ) + + datetime_obj = datetime_obj.astimezone(tz=datetime.timezone.utc) + self.assertEqual( + TimestamptzCustom.from_datetime(datetime_obj).datetime, + datetime_obj, + ) diff --git a/tests/columns/test_double_precision.py b/tests/columns/test_double_precision.py new file mode 100644 index 000000000..e29a0e134 --- /dev/null +++ b/tests/columns/test_double_precision.py @@ -0,0 +1,20 @@ +from piccolo.columns.column_types import DoublePrecision +from piccolo.table import Table +from piccolo.testing.test_case import TableTest + + +class MyTable(Table): + column_a = DoublePrecision() + + +class TestDoublePrecision(TableTest): + tables = [MyTable] + + def test_creation(self): + row = MyTable(column_a=1.23) + row.save().run_sync() + + _row = MyTable.objects().first().run_sync() + assert _row is not None + self.assertEqual(type(_row.column_a), float) + self.assertAlmostEqual(_row.column_a, 1.23) diff --git a/tests/columns/test_foreignkey.py b/tests/columns/test_foreignkey.py deleted file mode 100644 index 90a2c4d7e..000000000 --- a/tests/columns/test_foreignkey.py +++ /dev/null @@ -1,203 +0,0 @@ -import time -from unittest import TestCase - -from piccolo.columns import Column, ForeignKey, LazyTableReference, Varchar -from piccolo.table import Table -from tests.base import DBTestCase -from tests.example_app.tables import Band, Manager - - -class Manager1(Table, tablename="manager"): - name = Varchar() - manager = ForeignKey("self", null=True) - - -class Band1(Table, tablename="band"): - manager = ForeignKey(references=Manager1) - - -class Band2(Table, tablename="band"): - manager = ForeignKey(references="Manager1") - - -class Band3(Table, tablename="band"): - manager = ForeignKey( - references=LazyTableReference( - table_class_name="Manager1", - module_path="tests.columns.test_foreignkey", - ) - ) - - -class Band4(Table, 
tablename="band"): - manager = ForeignKey(references="tests.columns.test_foreignkey.Manager1") - - -class TestForeignKeySelf(TestCase): - """ - Test that ForeignKey columns can be created with references to the parent - table. - """ - - def setUp(self): - Manager1.create_table().run_sync() - - def test_foreign_key_self(self): - manager = Manager1(name="Mr Manager") - manager.save().run_sync() - - worker = Manager1(name="Mr Worker", manager=manager.id) - worker.save().run_sync() - - response = ( - Manager1.select(Manager1.name, Manager1.manager.name) - .order_by(Manager1.name) - .run_sync() - ) - self.assertEqual( - response, - [ - {"name": "Mr Manager", "manager.name": None}, - {"name": "Mr Worker", "manager.name": "Mr Manager"}, - ], - ) - - def tearDown(self): - Manager1.alter().drop_table().run_sync() - - -class TestForeignKeyString(TestCase): - """ - Test that ForeignKey columns can be created with a `references` argument - set as a string value. - """ - - def setUp(self): - Manager1.create_table().run_sync() - - def test_foreign_key_string(self): - Band2.create_table().run_sync() - self.assertEqual( - Band2.manager._foreign_key_meta.resolved_references, - Manager1, - ) - Band2.alter().drop_table().run_sync() - - Band4.create_table().run_sync() - self.assertEqual( - Band4.manager._foreign_key_meta.resolved_references, - Manager1, - ) - Band4.alter().drop_table().run_sync() - - def tearDown(self): - Manager1.alter().drop_table().run_sync() - - -class TestForeignKeyRelativeError(TestCase): - def test_foreign_key_relative_error(self): - """ - Make sure that a references argument which contains a relative module - isn't allowed. - """ - with self.assertRaises(ValueError) as manager: - - class BandRelative(Table, tablename="band"): - manager = ForeignKey("..example_app.tables.Manager", null=True) - - self.assertEqual( - manager.exception.__str__(), "Relative imports aren't allowed" - ) - - -class TestReferences(TestCase): - def test_foreign_key_references(self): - """ - Make sure foreign key references are stored correctly on the table - which is the target of the ForeignKey. - """ - self.assertEqual(len(Manager1._meta.foreign_key_references), 5) - - self.assertTrue(Band1.manager in Manager._meta.foreign_key_references) - self.assertTrue(Band2.manager in Manager._meta.foreign_key_references) - self.assertTrue(Band3.manager in Manager._meta.foreign_key_references) - self.assertTrue(Band4.manager in Manager._meta.foreign_key_references) - self.assertTrue( - Manager1.manager in Manager1._meta.foreign_key_references - ) - - -class TestLazyTableReference(TestCase): - def test_lazy_reference_to_app(self): - """ - Make sure a LazyTableReference to a Table within a Piccolo app works. - """ - reference = LazyTableReference( - table_class_name="Manager", app_name="example_app" - ) - self.assertTrue(reference.resolve() is Manager) - - -class TestAttributeAccess(TestCase): - def test_attribute_access(self): - """ - Make sure that attribute access still works correctly with lazy - references. - """ - self.assertTrue(isinstance(Band1.manager.name, Varchar)) - self.assertTrue(isinstance(Band2.manager.name, Varchar)) - self.assertTrue(isinstance(Band3.manager.name, Varchar)) - self.assertTrue(isinstance(Band4.manager.name, Varchar)) - - def test_recursion_limit(self): - """ - When a table has a ForeignKey to itself, an Exception should be raised - if the call chain is too large. 
- """ - # Should be fine: - column: Column = Manager1.manager.name - self.assertTrue(len(column._meta.call_chain), 1) - self.assertTrue(isinstance(column, Varchar)) - - with self.assertRaises(Exception): - Manager1.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.name # noqa - - def test_recursion_time(self): - """ - Make sure that a really large call chain doesn't take too long. - """ - start = time.time() - Manager1.manager.manager.manager.manager.manager.manager.name - end = time.time() - self.assertTrue(end - start < 1.0) - - -class TestAllColumns(DBTestCase): - def setUp(self): - Manager.create_table().run_sync() - manager = Manager(name="Guido") - manager.save().run_sync() - - Band.create_table().run_sync() - Band(manager=manager, name="Pythonistas").save().run_sync() - - def tearDown(self): - Band.alter().drop_table().run_sync() - Manager.alter().drop_table().run_sync() - - def test_all_columns(self): - """ - Make sure you can retrieve all columns from a related table, without - explicitly specifying them. - """ - result = Band.select(Band.name, *Band.manager.all_columns()).run_sync() - self.assertEqual( - result, - [ - { - "name": "Pythonistas", - "manager.id": 1, - "manager.name": "Guido", - } - ], - ) diff --git a/tests/columns/test_get_sql_value.py b/tests/columns/test_get_sql_value.py new file mode 100644 index 000000000..9a5d1c7d8 --- /dev/null +++ b/tests/columns/test_get_sql_value.py @@ -0,0 +1,66 @@ +import datetime +from unittest import TestCase + +from tests.base import engines_only +from tests.example_apps.music.tables import Band + + +@engines_only("postgres", "cockroach") +class TestArrayPostgres(TestCase): + + def test_string(self): + self.assertEqual( + Band.name.get_sql_value(["a", "b", "c"]), + '\'{"a","b","c"}\'', + ) + + def test_int(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3]), + "'{1,2,3}'", + ) + + def test_nested(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3, [4, 5, 6]]), + "'{1,2,3,{4,5,6}}'", + ) + + def test_time(self): + self.assertEqual( + Band.name.get_sql_value([datetime.time(hour=8, minute=0)]), + "'{\"08:00:00\"}'", + ) + + +@engines_only("sqlite") +class TestArraySQLite(TestCase): + """ + Note, we use ``.replace(" ", "")`` because we serialise arrays using + Python's json library, and there is inconsistency between Python versions + (some output ``["a", "b", "c"]``, and others ``["a","b","c"]``). 
+ """ + + def test_string(self): + self.assertEqual( + Band.name.get_sql_value(["a", "b", "c"]).replace(" ", ""), + '\'["a","b","c"]\'', + ) + + def test_int(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3]).replace(" ", ""), + "'[1,2,3]'", + ) + + def test_nested(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3, [4, 5, 6]]).replace(" ", ""), + "'[1,2,3,[4,5,6]]'", + ) + + def test_time(self): + self.assertEqual( + Band.name.get_sql_value([datetime.time(hour=8, minute=0)]), + "'[\"08:00:00\"]'", + ) diff --git a/tests/columns/test_integer.py b/tests/columns/test_integer.py new file mode 100644 index 000000000..fe42aaf18 --- /dev/null +++ b/tests/columns/test_integer.py @@ -0,0 +1,32 @@ +from piccolo.columns.column_types import Integer +from piccolo.table import Table +from piccolo.testing.test_case import AsyncTableTest +from tests.base import sqlite_only + + +class MyTable(Table): + integer = Integer() + + +@sqlite_only +class TestInteger(AsyncTableTest): + tables = [MyTable] + + async def test_large_integer(self): + """ + Make sure large integers can be inserted and retrieved correctly. + + There was a bug with this in SQLite: + + https://github.com/piccolo-orm/piccolo/issues/1127 + + """ + integer = 625757527765811240 + + row = MyTable(integer=integer) + await row.save() + + _row = MyTable.objects().first().run_sync() + assert _row is not None + + self.assertEqual(_row.integer, integer) diff --git a/tests/columns/test_interval.py b/tests/columns/test_interval.py index 253fd09db..484038003 100644 --- a/tests/columns/test_interval.py +++ b/tests/columns/test_interval.py @@ -1,9 +1,9 @@ import datetime -from unittest import TestCase from piccolo.columns.column_types import Interval from piccolo.columns.defaults.interval import IntervalCustom from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -14,12 +14,8 @@ class MyTableDefault(Table): interval = Interval(default=IntervalCustom(days=1)) -class TestInterval(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestInterval(TableTest): + tables = [MyTable] def test_interval(self): # Test a range of different timedeltas @@ -48,6 +44,7 @@ def test_interval(self): .first() .run_sync() ) + assert result is not None self.assertEqual(result.interval, interval) def test_interval_where_clause(self): @@ -64,7 +61,7 @@ def test_interval_where_clause(self): .first() .run_sync() ) - self.assertTrue(result is not None) + self.assertIsNotNone(result) result = ( MyTable.objects() @@ -72,7 +69,7 @@ def test_interval_where_clause(self): .first() .run_sync() ) - self.assertTrue(result is not None) + self.assertIsNotNone(result) result = ( MyTable.objects() @@ -80,7 +77,7 @@ def test_interval_where_clause(self): .first() .run_sync() ) - self.assertTrue(result is not None) + self.assertIsNotNone(result) result = ( MyTable.exists() @@ -90,16 +87,13 @@ def test_interval_where_clause(self): self.assertTrue(result) -class TestIntervalDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestIntervalDefault(TableTest): + tables = [MyTableDefault] def test_interval(self): row = MyTableDefault() row.save().run_sync() result = MyTableDefault.objects().first().run_sync() - self.assertTrue(result.interval.days == 1) + assert result is not None + self.assertEqual(result.interval.days, 1) diff --git 
a/tests/columns/test_json.py b/tests/columns/test_json.py index 16663f56b..19669c61b 100644 --- a/tests/columns/test_json.py +++ b/tests/columns/test_json.py @@ -1,7 +1,6 @@ -from unittest import TestCase - from piccolo.columns.column_types import JSON from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -20,12 +19,8 @@ class MyTableDefault(Table): json_none = JSON(default=None, null=True) -class TestJSONSave(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestJSONSave(TableTest): + tables = [MyTable] def test_json_string(self): """ @@ -34,11 +29,11 @@ def test_json_string(self): row = MyTable(json='{"a": 1}') row.save().run_sync() + row_from_db = MyTable.select(MyTable.json).first().run_sync() + assert row_from_db is not None + self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row_from_db["json"].replace(" ", ""), '{"a":1}', ) @@ -49,21 +44,17 @@ def test_json_object(self): row = MyTable(json={"a": 1}) row.save().run_sync() + row_from_db = MyTable.select(MyTable.json).first().run_sync() + assert row_from_db is not None + self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row_from_db["json"].replace(" ", ""), '{"a":1}', ) -class TestJSONDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestJSONDefault(TableTest): + tables = [MyTableDefault] def test_json_default(self): row = MyTableDefault() @@ -78,22 +69,17 @@ def test_json_default(self): def test_invalid_default(self): with self.assertRaises(ValueError): for value in ("a", 1, ("x", "y", "z")): - JSON(default=value) - + JSON(default=value) # type: ignore -class TestJSONInsert(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestJSONInsert(TableTest): + tables = [MyTable] def check_response(self): + row = MyTable.select(MyTable.json).first().run_sync() + assert row is not None self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row["json"].replace(" ", ""), '{"message":"original"}', ) @@ -113,23 +99,18 @@ def test_json_object(self): MyTable.insert(row).run_sync() -class TestJSONUpdate(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestJSONUpdate(TableTest): + tables = [MyTable] def add_row(self): row = MyTable(json={"message": "original"}) row.save().run_sync() def check_response(self): + row = MyTable.select(MyTable.json).first().run_sync() + assert row is not None self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row["json"].replace(" ", ""), '{"message":"updated"}', ) @@ -138,7 +119,9 @@ def test_json_update_string(self): Test updating a JSON field using a string. """ self.add_row() - MyTable.update({MyTable.json: '{"message": "updated"}'}).run_sync() + MyTable.update( + {MyTable.json: '{"message": "updated"}'}, force=True + ).run_sync() self.check_response() def test_json_update_object(self): @@ -146,5 +129,7 @@ def test_json_update_object(self): Test updating a JSON field using an object. 
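(Note the ``force=True`` being added to the ``update`` calls in this hunk: newer Piccolo versions refuse to run an ``update`` that has no ``where`` clause unless it is explicitly forced, as a guard against accidentally rewriting every row. In outline, reusing the ``MyTable`` defined above:)

    # No where clause, so every row would be modified - Piccolo requires
    # an explicit opt-in:
    MyTable.update(
        {MyTable.json: '{"message": "updated"}'}, force=True
    ).run_sync()

    # With a where clause, no force argument is needed:
    MyTable.update({MyTable.json: '{"message": "updated"}'}).where(
        MyTable.json.is_not_null()
    ).run_sync()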
""" self.add_row() - MyTable.update({MyTable.json: {"message": "updated"}}).run_sync() + MyTable.update( + {MyTable.json: {"message": "updated"}}, force=True + ).run_sync() self.check_response() diff --git a/tests/columns/test_jsonb.py b/tests/columns/test_jsonb.py index b5bfe923b..f38c0de05 100644 --- a/tests/columns/test_jsonb.py +++ b/tests/columns/test_jsonb.py @@ -1,76 +1,286 @@ -from unittest import TestCase - -from piccolo.columns.column_types import JSONB +from piccolo.columns.column_types import JSONB, ForeignKey, Varchar from piccolo.table import Table +from piccolo.testing.test_case import AsyncTableTest, TableTest +from tests.base import engines_only, engines_skip -from ..base import postgres_only +class RecordingStudio(Table): + name = Varchar() + facilities = JSONB(null=True) -class MyTable(Table): - json = JSONB() +class Instrument(Table): + name = Varchar() + studio = ForeignKey(RecordingStudio) -@postgres_only -class TestJSONB(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - def tearDown(self): - MyTable.alter().drop_table().run_sync() +@engines_only("postgres", "cockroach") +class TestJSONB(TableTest): + tables = [RecordingStudio, Instrument] def test_json(self): """ Test storing a valid JSON string. """ - row = MyTable(json='{"a": 1}') + row = RecordingStudio( + name="Abbey Road", facilities='{"mixing_desk": true}' + ) row.save().run_sync() - self.assertEqual(row.json, '{"a": 1}') + self.assertEqual(row.facilities, '{"mixing_desk": true}') - def test_arrow(self): + @engines_skip("cockroach") + def test_raw(self): """ - Test using the arrow function to retrieve a subset of the JSON. + Make sure raw queries convert the Python value into a JSON string. + """ + RecordingStudio.raw( + "INSERT INTO recording_studio (name, facilities) VALUES ({}, {})", + "Abbey Road", + '{"mixing_desk": true}', + ).run_sync() + + self.assertEqual( + RecordingStudio.select().run_sync(), + [ + { + "id": 1, + "name": "Abbey Road", + "facilities": '{"mixing_desk": true}', + } + ], + ) + + @engines_only("cockroach") + def test_raw_alt(self): + """ + Make sure raw queries convert the Python value into a JSON string. + """ + result = RecordingStudio.raw( + "INSERT INTO recording_studio (name, facilities) VALUES ({}, {}) returning id", # noqa: E501 + "Abbey Road", + '{"mixing_desk": true}', + ).run_sync() + + self.assertEqual( + RecordingStudio.select().run_sync(), + [ + { + "id": result[0]["id"], + "name": "Abbey Road", + "facilities": '{"mixing_desk": true}', + } + ], + ) + + def test_where(self): + """ + Test using the where clause to match a subset of rows. 
+ """ + RecordingStudio.insert( + RecordingStudio( + name="Abbey Road", facilities={"mixing_desk": True} + ), + RecordingStudio(name="ABC Studio", facilities=None), + ).run_sync() + + self.assertEqual( + RecordingStudio.select(RecordingStudio.name) + .where(RecordingStudio.facilities == {"mixing_desk": True}) + .run_sync(), + [{"name": "Abbey Road"}], + ) + + self.assertEqual( + RecordingStudio.select(RecordingStudio.name) + .where(RecordingStudio.facilities == '{"mixing_desk": true}') + .run_sync(), + [{"name": "Abbey Road"}], + ) + + self.assertEqual( + RecordingStudio.select(RecordingStudio.name) + .where(RecordingStudio.facilities.is_null()) + .run_sync(), + [{"name": "ABC Studio"}], + ) + + self.assertEqual( + RecordingStudio.select(RecordingStudio.name) + .where(RecordingStudio.facilities.is_not_null()) + .run_sync(), + [{"name": "Abbey Road"}], + ) + + def test_as_alias_join(self): """ - MyTable(json='{"a": 1}').save().run_sync() - row = MyTable.select(MyTable.json.arrow("a")).first().run_sync() - self.assertEqual(row["?column?"], "1") + Make sure that ``as_alias`` performs correctly when used via a join. + """ + studio = ( + RecordingStudio.objects() + .create(name="Abbey Road", facilities={"mixing_desk": True}) + .run_sync() + ) + + Instrument.objects().create(name="Guitar", studio=studio).run_sync() + + response = ( + Instrument.select( + Instrument.name, + Instrument.studio.facilities.as_alias("studio_facilities"), + ) + .output(load_json=True) + .run_sync() + ) + + self.assertListEqual( + response, + [{"name": "Guitar", "studio_facilities": {"mixing_desk": True}}], + ) + + +@engines_only("postgres", "cockroach") +class TestArrow(AsyncTableTest): + tables = [RecordingStudio, Instrument] + + async def insert_row(self): + await RecordingStudio( + name="Abbey Road", facilities='{"mixing_desk": true}' + ).save() - def test_arrow_as_alias(self): + async def test_arrow(self): """ Test using the arrow function to retrieve a subset of the JSON. """ - MyTable(json='{"a": 1}').save().run_sync() - row = ( - MyTable.select(MyTable.json.arrow("a").as_alias("a")) + await self.insert_row() + + row = await RecordingStudio.select( + RecordingStudio.facilities.arrow("mixing_desk") + ).first() + assert row is not None + self.assertEqual(row["facilities"], "true") + + row = await ( + RecordingStudio.select( + RecordingStudio.facilities.arrow("mixing_desk") + ) + .output(load_json=True) .first() - .run_sync() ) - self.assertEqual(row["a"], "1") + assert row is not None + self.assertEqual(row["facilities"], True) + + async def test_arrow_as_alias(self): + """ + Test using the arrow function to retrieve a subset of the JSON. + """ + await self.insert_row() + + row = await RecordingStudio.select( + RecordingStudio.facilities.arrow("mixing_desk").as_alias( + "mixing_desk" + ) + ).first() + assert row is not None + self.assertEqual(row["mixing_desk"], "true") + + async def test_square_brackets(self): + """ + Make sure we can use square brackets instead of calling ``arrow`` + explicitly. + """ + await self.insert_row() + + row = await RecordingStudio.select( + RecordingStudio.facilities["mixing_desk"].as_alias("mixing_desk") + ).first() + assert row is not None + self.assertEqual(row["mixing_desk"], "true") + + async def test_multiple_levels_deep(self): + """ + Make sure elements can be extracted multiple levels deep, and using + array indexes. 
+ """ + await RecordingStudio( + name="Abbey Road", + facilities={ + "technicians": [ + {"name": "Alice Jones"}, + {"name": "Bob Williams"}, + ] + }, + ).save() + + response = await RecordingStudio.select( + RecordingStudio.facilities["technicians"][0]["name"].as_alias( + "technician_name" + ) + ).output(load_json=True) + assert response is not None + self.assertListEqual(response, [{"technician_name": "Alice Jones"}]) - def test_arrow_where(self): + async def test_arrow_where(self): """ Make sure the arrow function can be used within a WHERE clause. """ - MyTable(json='{"a": 1}').save().run_sync() + await self.insert_row() + self.assertEqual( - MyTable.count().where(MyTable.json.arrow("a") == "1").run_sync(), 1 + await RecordingStudio.count().where( + RecordingStudio.facilities.arrow("mixing_desk").eq(True) + ), + 1, ) self.assertEqual( - MyTable.count().where(MyTable.json.arrow("a") == "2").run_sync(), 0 + await RecordingStudio.count().where( + RecordingStudio.facilities.arrow("mixing_desk").eq(False) + ), + 0, ) - def test_arrow_first(self): + async def test_arrow_first(self): """ Make sure the arrow function can be used with the first clause. """ - MyTable.insert( - MyTable(json='{"a": 1}'), - MyTable(json='{"b": 2}'), - ).run_sync() + await RecordingStudio.insert( + RecordingStudio(facilities='{"mixing_desk": true}'), + RecordingStudio(facilities='{"mixing_desk": false}'), + ) self.assertEqual( - MyTable.select(MyTable.json.arrow("a").as_alias("json")) - .first() - .run_sync(), - {"json": "1"}, + await RecordingStudio.select( + RecordingStudio.facilities.arrow("mixing_desk").as_alias( + "mixing_desk" + ) + ).first(), + {"mixing_desk": "true"}, ) + + +@engines_only("postgres", "cockroach") +class TestFromPath(AsyncTableTest): + + tables = [RecordingStudio, Instrument] + + async def test_from_path(self): + """ + Make sure ``from_path`` can be used for complex nested data. 
+ """ + await RecordingStudio( + name="Abbey Road", + facilities={ + "technicians": [ + {"name": "Alice Jones"}, + {"name": "Bob Williams"}, + ] + }, + ).save() + + response = await RecordingStudio.select( + RecordingStudio.facilities.from_path( + ["technicians", 0, "name"] + ).as_alias("technician_name") + ).output(load_json=True) + assert response is not None + self.assertListEqual(response, [{"technician_name": "Alice Jones"}]) diff --git a/tests/columns/test_numeric.py b/tests/columns/test_numeric.py index 1482dc8ad..22c650c70 100644 --- a/tests/columns/test_numeric.py +++ b/tests/columns/test_numeric.py @@ -1,8 +1,8 @@ from decimal import Decimal -from unittest import TestCase from piccolo.columns.column_types import Numeric from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -10,21 +10,18 @@ class MyTable(Table): column_b = Numeric(digits=(3, 2)) -class TestNumeric(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestNumeric(TableTest): + tables = [MyTable] def test_creation(self): row = MyTable(column_a=Decimal(1.23), column_b=Decimal(1.23)) row.save().run_sync() _row = MyTable.objects().first().run_sync() + assert _row is not None - self.assertTrue(type(_row.column_a) == Decimal) - self.assertTrue(type(_row.column_b) == Decimal) + self.assertEqual(type(_row.column_a), Decimal) + self.assertEqual(type(_row.column_b), Decimal) self.assertAlmostEqual(_row.column_a, Decimal(1.23)) - self.assertEqual(_row.column_b, Decimal("1.23")) + self.assertAlmostEqual(_row.column_b, Decimal("1.23")) diff --git a/tests/columns/test_primary_key.py b/tests/columns/test_primary_key.py index 42629a495..86868a2c8 100644 --- a/tests/columns/test_primary_key.py +++ b/tests/columns/test_primary_key.py @@ -1,8 +1,14 @@ import uuid -from unittest import TestCase -from piccolo.columns.column_types import UUID, ForeignKey, Serial, Varchar +from piccolo.columns.column_types import ( + UUID, + BigSerial, + ForeignKey, + Serial, + Varchar, +) from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTableDefaultPrimaryKey(Table): @@ -14,72 +20,73 @@ class MyTablePrimaryKeySerial(Table): name = Varchar() -class MyTablePrimaryKeyUUID(Table): - id = UUID(null=False, primary_key=True) +class MyTablePrimaryKeyBigSerial(Table): + pk = BigSerial(null=False, primary_key=True) name = Varchar() -class TestPrimaryKeyDefault(TestCase): - def setUp(self): - MyTableDefaultPrimaryKey.create_table().run_sync() +class MyTablePrimaryKeyUUID(Table): + pk = UUID(null=False, primary_key=True) + name = Varchar() + - def tearDown(self): - MyTableDefaultPrimaryKey.alter().drop_table().run_sync() +class TestPrimaryKeyDefault(TableTest): + tables = [MyTableDefaultPrimaryKey] def test_return_type(self): row = MyTableDefaultPrimaryKey() row.save().run_sync() self.assertIsInstance(row._meta.primary_key, Serial) + self.assertIsInstance(row["id"], int) -class TestPrimaryKeyInteger(TestCase): - def setUp(self): - MyTablePrimaryKeySerial.create_table().run_sync() - - def tearDown(self): - MyTablePrimaryKeySerial.alter().drop_table().run_sync() +class TestPrimaryKeyInteger(TableTest): + tables = [MyTablePrimaryKeySerial] def test_return_type(self): row = MyTablePrimaryKeySerial() - result = row.save().run_sync()[0] + row.save().run_sync() + + self.assertIsInstance(row._meta.primary_key, Serial) + self.assertIsInstance(row["pk"], int) - self.assertIsInstance(result["pk"], int) 
+class TestPrimaryKeyBigSerial(TableTest): + tables = [MyTablePrimaryKeyBigSerial] + + def test_return_type(self): + row = MyTablePrimaryKeyBigSerial() + row.save().run_sync() + + self.assertIsInstance(row._meta.primary_key, BigSerial) + self.assertIsInstance(row["pk"], int) -class TestPrimaryKeyUUID(TestCase): - def setUp(self): - MyTablePrimaryKeyUUID.create_table().run_sync() - def tearDown(self): - MyTablePrimaryKeyUUID.alter().drop_table().run_sync() +class TestPrimaryKeyUUID(TableTest): + tables = [MyTablePrimaryKeyUUID] def test_return_type(self): row = MyTablePrimaryKeyUUID() row.save().run_sync() - self.assertIsInstance(row.id, uuid.UUID) + self.assertIsInstance(row._meta.primary_key, UUID) + self.assertIsInstance(row["pk"], uuid.UUID) class Manager(Table): - pk = UUID(primary=True, key=True) + pk = UUID(primary_key=True) name = Varchar() class Band(Table): - pk = UUID(primary=True, key=True) + pk = UUID(primary_key=True) name = Varchar() manager = ForeignKey(Manager) -class TestPrimaryKeyQueries(TestCase): - def setUp(self): - Manager.create_table().run_sync() - Band.create_table().run_sync() - - def tearDown(self): - Band.alter().drop_table().run_sync() - Manager.alter().drop_table().run_sync() +class TestPrimaryKeyQueries(TableTest): + tables = [Manager, Band] def test_primary_key_queries(self): """ @@ -104,13 +111,14 @@ def test_primary_key_queries(self): ) manager_dict = Manager.select().first().run_sync() + assert manager_dict is not None self.assertEqual( [i for i in manager_dict.keys()], ["pk", "name"], ) - self.assertTrue(isinstance(manager_dict["pk"], uuid.UUID)) + self.assertIsInstance(manager_dict["pk"], uuid.UUID) ####################################################################### # Make sure we can create rows with foreign keys to tables with a @@ -122,18 +130,20 @@ def test_primary_key_queries(self): band.save().run_sync() band_dict = Band.select().first().run_sync() + assert band_dict is not None self.assertEqual( [i for i in band_dict.keys()], ["pk", "name", "manager"] ) - self.assertTrue(isinstance(band_dict["pk"], uuid.UUID)) - self.assertTrue(isinstance(band_dict["manager"], uuid.UUID)) + self.assertIsInstance(band_dict["pk"], uuid.UUID) + self.assertIsInstance(band_dict["manager"], uuid.UUID) ####################################################################### # Make sure foreign key values can be specified as the primary key's # type (i.e. `uuid.UUID`). 
manager = Manager.objects().first().run_sync() + assert manager is not None band_2 = Band(manager=manager.pk, name="Pythonistas 2") band_2.save().run_sync() @@ -149,9 +159,10 @@ def test_primary_key_queries(self): ####################################################################### # Make sure `get_related` works - self.assertEqual( - band_2.get_related(Band.manager).run_sync().pk, manager.pk - ) + manager_from_db = band_2.get_related(Band.manager).run_sync() + assert manager_from_db is not None + + self.assertEqual(manager_from_db.pk, manager.pk) ####################################################################### # Make sure `remove` works diff --git a/tests/columns/test_readable.py b/tests/columns/test_readable.py index e7225fbd1..d04036147 100644 --- a/tests/columns/test_readable.py +++ b/tests/columns/test_readable.py @@ -1,8 +1,7 @@ -import unittest - from piccolo import columns from piccolo.columns.readable import Readable from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -16,14 +15,13 @@ def get_readable(cls) -> Readable: ) -class TestReadable(unittest.TestCase): +class TestReadable(TableTest): + tables = [MyTable] + def setUp(self): - MyTable.create_table().run_sync() + super().setUp() MyTable(first_name="Guido", last_name="van Rossum").save().run_sync() def test_readable(self): response = MyTable.select(MyTable.get_readable()).run_sync() - self.assertTrue(response[0]["readable"] == "Guido van Rossum") - - def tearDown(self): - MyTable.alter().drop_table().run_sync() + self.assertEqual(response[0]["readable"], "Guido van Rossum") diff --git a/tests/columns/test_real.py b/tests/columns/test_real.py index 09bcdeb40..a2cef5a75 100644 --- a/tests/columns/test_real.py +++ b/tests/columns/test_real.py @@ -1,24 +1,20 @@ -from unittest import TestCase - from piccolo.columns.column_types import Real from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): column_a = Real() -class TestReal(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestReal(TableTest): + tables = [MyTable] def test_creation(self): row = MyTable(column_a=1.23) row.save().run_sync() _row = MyTable.objects().first().run_sync() - self.assertTrue(type(_row.column_a) == float) + assert _row is not None + self.assertEqual(type(_row.column_a), float) self.assertAlmostEqual(_row.column_a, 1.23) diff --git a/tests/columns/test_reference.py b/tests/columns/test_reference.py index 391592a84..8a10b6207 100644 --- a/tests/columns/test_reference.py +++ b/tests/columns/test_reference.py @@ -2,12 +2,44 @@ Most of the tests for piccolo/columns/reference.py are covered in piccolo/columns/test_foreignkey.py """ + from unittest import TestCase +from piccolo.columns import ForeignKey, Varchar from piccolo.columns.reference import LazyTableReference +from piccolo.table import Table +from piccolo.testing.test_case import TableTest + + +class Band(Table): + manager: ForeignKey["Manager"] = ForeignKey( + LazyTableReference("Manager", module_path=__name__) + ) + name = Varchar() + + +class Manager(Table): + name = Varchar() + +class TestQueries(TableTest): + tables = [Band, Manager] -class TestLazyTableReference(TestCase): + def setUp(self): + super().setUp() + manager = Manager({Manager.name: "Guido"}) + manager.save().run_sync() + band = Band({Band.name: "Pythonistas", Band.manager: manager}) + band.save().run_sync() + + def test_select(self): + 
self.assertListEqual( + Band.select(Band.name, Band.manager._.name).run_sync(), + [{"name": "Pythonistas", "manager.name": "Guido"}], + ) + + +class TestInit(TestCase): def test_init(self): """ A ``LazyTableReference`` must be passed either an ``app_name`` or @@ -19,33 +51,35 @@ def test_init(self): with self.assertRaises(ValueError): LazyTableReference( table_class_name="Manager", - app_name="example_app", - module_path="tests.example_app.tables", + app_name="music", + module_path="tests.example_apps.music.tables", ) # Shouldn't raise exceptions: LazyTableReference( table_class_name="Manager", - app_name="example_app", + app_name="music", ) LazyTableReference( table_class_name="Manager", - module_path="tests.example_app.tables", + module_path="tests.example_apps.music.tables", ) + +class TestStr(TestCase): def test_str(self): self.assertEqual( LazyTableReference( table_class_name="Manager", - app_name="example_app", + app_name="music", ).__str__(), - "App example_app.Manager", + "App music.Manager", ) self.assertEqual( LazyTableReference( table_class_name="Manager", - module_path="tests.example_app.tables", + module_path="tests.example_apps.music.tables", ).__str__(), - "Module tests.example_app.tables.Manager", + "Module tests.example_apps.music.tables.Manager", ) diff --git a/tests/columns/test_reserved_column_names.py b/tests/columns/test_reserved_column_names.py new file mode 100644 index 000000000..1fc4a464e --- /dev/null +++ b/tests/columns/test_reserved_column_names.py @@ -0,0 +1,55 @@ +from piccolo.columns.column_types import Integer, Varchar +from piccolo.table import Table +from piccolo.testing.test_case import TableTest + + +class Concert(Table): + """ + ``order`` is a problematic name, as it clashes with a reserved SQL keyword: + + https://www.postgresql.org/docs/current/sql-keywords-appendix.html + + """ + + name = Varchar() + order = Integer() + + +class TestReservedColumnNames(TableTest): + """ + Make sure the table works as expected, even though it has a problematic + column name. + """ + + tables = [Concert] + + def test_common_operations(self): + # Save / Insert + concert = Concert(name="Royal Albert Hall", order=1) + concert.save().run_sync() + self.assertEqual( + Concert.select(Concert.order).run_sync(), + [{"order": 1}], + ) + + # Save / Update + concert.order = 2 + concert.save().run_sync() + self.assertEqual( + Concert.select(Concert.order).run_sync(), + [{"order": 2}], + ) + + # Update + Concert.update({Concert.order: 3}, force=True).run_sync() + self.assertEqual( + Concert.select(Concert.order).run_sync(), + [{"order": 3}], + ) + + # Delete + Concert.delete().where(Concert.order == 3).run_sync() + self.assertEqual( + Concert.select(Concert.order).run_sync(), + [], + ) diff --git a/tests/columns/test_smallint.py b/tests/columns/test_smallint.py index 45b8ff683..808562322 100644 --- a/tests/columns/test_smallint.py +++ b/tests/columns/test_smallint.py @@ -1,31 +1,26 @@ import os -from unittest import TestCase from piccolo.columns.column_types import SmallInt from piccolo.table import Table - -from ..base import postgres_only +from piccolo.testing.test_case import TableTest +from tests.base import engines_only class MyTable(Table): value = SmallInt() -@postgres_only -class TestSmallIntPostgres(TestCase): +@engines_only("postgres", "cockroach") +class TestSmallIntPostgres(TableTest): """ Make sure a SmallInt column in Postgres can only store small numbers. 
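(The arithmetic behind ``_test_length`` below: a SMALLINT is 2 bytes, i.e. 2**16 = 65,536 distinct values, split between negative and non-negative numbers:)

    # 2 bytes = 16 bits = 65,536 distinct values, split around zero.
    max_value = int(2**16 / 2) - 1  # 32767, i.e. 2**15 - 1
    min_value = max_value * -1      # -32767 (Postgres itself allows -32768)

    assert (max_value, min_value) == (32767, -32767)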
""" - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() + tables = [MyTable] def _test_length(self): # Can store 2 bytes, but split between positive and negative values. - max_value = int(2 ** 16 / 2) - 1 + max_value = int(2**16 / 2) - 1 min_value = max_value * -1 print("Testing max value") diff --git a/tests/columns/test_time.py b/tests/columns/test_time.py index 56282b262..9fc48aaad 100644 --- a/tests/columns/test_time.py +++ b/tests/columns/test_time.py @@ -1,10 +1,11 @@ import datetime from functools import partial -from unittest import TestCase from piccolo.columns.column_types import Time from piccolo.columns.defaults.time import TimeNow from piccolo.table import Table +from piccolo.testing.test_case import TableTest +from tests.base import engines_skip class MyTable(Table): @@ -15,29 +16,24 @@ class MyTableDefault(Table): created_on = Time(default=TimeNow()) -class TestTime(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestTime(TableTest): + tables = [MyTable] + @engines_skip("cockroach") def test_timestamp(self): created_on = datetime.datetime.now().time() row = MyTable(created_on=created_on) row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) -class TestTimeDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestTimeDefault(TableTest): + tables = [MyTableDefault] + @engines_skip("cockroach") def test_timestamp(self): created_on = datetime.datetime.now().time() row = MyTableDefault() @@ -46,7 +42,8 @@ def test_timestamp(self): _datetime = partial(datetime.datetime, year=2020, month=1, day=1) result = MyTableDefault.objects().first().run_sync() - self.assertTrue( + assert result is not None + self.assertLess( _datetime( hour=result.created_on.hour, minute=result.created_on.minute, @@ -56,6 +53,6 @@ def test_timestamp(self): hour=created_on.hour, minute=created_on.minute, second=created_on.second, - ) - < datetime.timedelta(seconds=1) + ), + datetime.timedelta(seconds=1), ) diff --git a/tests/columns/test_timestamp.py b/tests/columns/test_timestamp.py index 70ba3fa28..084da7c6c 100644 --- a/tests/columns/test_timestamp.py +++ b/tests/columns/test_timestamp.py @@ -1,9 +1,9 @@ import datetime -from unittest import TestCase from piccolo.columns.column_types import Timestamp from piccolo.columns.defaults.timestamp import TimestampNow from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -19,12 +19,8 @@ class MyTableDefault(Table): created_on = Timestamp(default=TimestampNow()) -class TestTimestamp(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestTimestamp(TableTest): + tables = [MyTable] def test_timestamp(self): """ @@ -35,6 +31,7 @@ def test_timestamp(self): row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) def test_timezone_aware(self): @@ -45,12 +42,8 @@ def test_timezone_aware(self): Timestamp(default=datetime.datetime.now(tz=datetime.timezone.utc)) -class TestTimestampDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - 
MyTableDefault.alter().drop_table().run_sync() +class TestTimestampDefault(TableTest): + tables = [MyTableDefault] def test_timestamp(self): """ @@ -61,7 +54,8 @@ def test_timestamp(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() - self.assertTrue( - result.created_on - created_on < datetime.timedelta(seconds=1) + assert result is not None + self.assertLess( + result.created_on - created_on, datetime.timedelta(seconds=1) ) - self.assertTrue(result.created_on.tzinfo is None) + self.assertIsNone(result.created_on.tzinfo) diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 3f66ab4dc..cf3528b9a 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -1,5 +1,4 @@ import datetime -from unittest import TestCase from dateutil import tz @@ -10,6 +9,7 @@ TimestamptzOffset, ) from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): @@ -34,12 +34,8 @@ class CustomTimezone(datetime.tzinfo): pass -class TestTimestamptz(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestTimestamptz(TableTest): + tables = [MyTable] def test_timestamptz_timezone_aware(self): """ @@ -71,18 +67,15 @@ def test_timestamptz_timezone_aware(self): .first() .run_sync() ) + assert result is not None self.assertEqual(result.created_on, created_on) # The database converts it to UTC self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) -class TestTimestamptzDefault(TestCase): - def setUp(self): - MyTableDefault.create_table().run_sync() - - def tearDown(self): - MyTableDefault.alter().drop_table().run_sync() +class TestTimestamptzDefault(TableTest): + tables = [MyTableDefault] def test_timestamptz_default(self): """ @@ -93,6 +86,7 @@ def test_timestamptz_default(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None delta = result.created_on - created_on - self.assertTrue(delta < datetime.timedelta(seconds=1)) + self.assertLess(delta, datetime.timedelta(seconds=1)) self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) diff --git a/tests/columns/test_uuid.py b/tests/columns/test_uuid.py index 64a197fc6..3dcce88a1 100644 --- a/tests/columns/test_uuid.py +++ b/tests/columns/test_uuid.py @@ -1,20 +1,16 @@ import uuid -from unittest import TestCase from piccolo.columns.column_types import UUID from piccolo.table import Table +from piccolo.testing.test_case import TableTest class MyTable(Table): uuid = UUID() -class TestUUID(TestCase): - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() +class TestUUID(TableTest): + tables = [MyTable] def test_return_type(self): row = MyTable() diff --git a/tests/columns/test_varchar.py b/tests/columns/test_varchar.py index cf258a9d1..c62a3a0fd 100644 --- a/tests/columns/test_varchar.py +++ b/tests/columns/test_varchar.py @@ -1,30 +1,24 @@ -from unittest import TestCase - from piccolo.columns.column_types import Varchar from piccolo.table import Table - -from ..base import postgres_only +from piccolo.testing.test_case import TableTest +from tests.base import engines_only class MyTable(Table): name = Varchar(length=10) -@postgres_only -class TestVarchar(TestCase): +@engines_only("postgres", "cockroach") +class TestVarchar(TableTest): """ SQLite doesn't enforce any constraints on max character length. 
https://www.sqlite.org/faq.html#q9 - Might consider enforncing this at the ORM level instead in the future. + Might consider enforcing this at the ORM level instead in the future. """ - def setUp(self): - MyTable.create_table().run_sync() - - def tearDown(self): - MyTable.alter().drop_table().run_sync() + tables = [MyTable] def test_length(self): row = MyTable(name="bob") diff --git a/tests/conf/example.py b/tests/conf/example.py new file mode 100644 index 000000000..ef4541627 --- /dev/null +++ b/tests/conf/example.py @@ -0,0 +1,13 @@ +""" +This file is used by test_apps.py to make sure we can exclude imported +``Table`` subclasses when using ``table_finder``. +""" + +from piccolo.apps.user.tables import BaseUser +from piccolo.columns.column_types import ForeignKey, Varchar +from piccolo.table import Table + + +class Musician(Table): + name = Varchar() + user = ForeignKey(BaseUser) diff --git a/tests/conf/test_apps.py b/tests/conf/test_apps.py index 737d8b6ca..e98c978a0 100644 --- a/tests/conf/test_apps.py +++ b/tests/conf/test_apps.py @@ -1,21 +1,41 @@ +from __future__ import annotations + +import pathlib +import tempfile from unittest import TestCase from piccolo.apps.user.tables import BaseUser -from piccolo.conf.apps import AppConfig, AppRegistry, table_finder - -from ..example_app.tables import Manager +from piccolo.conf.apps import ( + AppConfig, + AppRegistry, + Finder, + PiccoloConfUpdater, + table_finder, +) +from tests.example_apps.mega.tables import MegaTable, SmallTable +from tests.example_apps.music.tables import ( + Band, + Concert, + Instrument, + Manager, + Poster, + RecordingStudio, + Shirt, + Ticket, + Venue, +) class TestAppRegistry(TestCase): def test_get_app_config(self): app_registry = AppRegistry(apps=["piccolo.apps.user.piccolo_app"]) app_config = app_registry.get_app_config(app_name="user") - self.assertTrue(isinstance(app_config, AppConfig)) + self.assertIsInstance(app_config, AppConfig) def test_get_table_classes(self): app_registry = AppRegistry(apps=["piccolo.apps.user.piccolo_app"]) table_classes = app_registry.get_table_classes(app_name="user") - self.assertTrue(BaseUser in table_classes) + self.assertIn(BaseUser, table_classes) with self.assertRaises(ValueError): app_registry.get_table_classes(app_name="Foo") @@ -32,6 +52,29 @@ def test_duplicate_app_names(self): ] ) + def test_app_names_not_ending_piccolo_app(self): + """ + Should automatically add `.piccolo_app` to end. + """ + AppRegistry( + apps=[ + "piccolo.apps.user", + ] + ) + + def test_duplicate_app_names_with_auto_changed(self): + """ + Make sure duplicate app names are still detected when `piccolo_app` + is omitted from the end. + """ + with self.assertRaises(ValueError): + AppRegistry( + apps=[ + "piccolo.apps.user.piccolo_app", + "piccolo.apps.user", + ] + ) + def test_get_table_with_name(self): app_registry = AppRegistry(apps=["piccolo.apps.user.piccolo_app"]) table = app_registry.get_table_with_name( @@ -41,11 +84,21 @@ def test_get_table_with_name(self): class TestAppConfig(TestCase): + def test_pathlib(self): + """ + Make sure a ``pathlib.Path`` instance can be passed in as a + ``migrations_folder_path`` argument. + """ + config = AppConfig( + app_name="music", migrations_folder_path=pathlib.Path(__file__) + ) + self.assertEqual(config.resolved_migrations_folder_path, __file__) + def test_get_table_with_name(self): """ Register a table, then test retrieving it. 
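(As ``test_pathlib`` above pins down, ``AppConfig`` accepts a ``pathlib.Path`` as well as a plain string for ``migrations_folder_path``. A minimal register-then-retrieve sketch, with a made-up migrations folder and the ``Manager`` table from the test fixtures:)

    import pathlib

    from piccolo.conf.apps import AppConfig
    from tests.example_apps.music.tables import Manager

    config = AppConfig(
        app_name="music",
        migrations_folder_path=pathlib.Path(__file__).parent / "migrations",
    )
    config.register_table(table_class=Manager)

    assert config.get_table_with_name("Manager") is Manager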
""" - config = AppConfig(app_name="Music", migrations_folder_path="") + config = AppConfig(app_name="music", migrations_folder_path="") config.register_table(table_class=Manager) self.assertEqual(config.get_table_with_name("Manager"), Manager) @@ -58,7 +111,7 @@ def test_table_finder(self): """ Should return all Table subclasses. """ - tables = table_finder(modules=["tests.example_app.tables"]) + tables = table_finder(modules=["tests.example_apps.music.tables"]) table_class_names = [i.__name__ for i in tables] table_class_names.sort() @@ -68,6 +121,7 @@ def test_table_finder(self): [ "Band", "Concert", + "Instrument", "Manager", "Poster", "RecordingStudio", @@ -84,7 +138,7 @@ def test_table_finder_coercion(self): """ Should convert a string argument to a list. """ - tables = table_finder(modules="tests.example_app.tables") + tables = table_finder(modules="tests.example_apps.music.tables") table_class_names = [i.__name__ for i in tables] table_class_names.sort() @@ -94,6 +148,7 @@ def test_table_finder_coercion(self): [ "Band", "Concert", + "Instrument", "Manager", "Poster", "RecordingStudio", @@ -108,7 +163,8 @@ def test_include_tags(self): Should return all Table subclasses with a matching tag. """ tables = table_finder( - modules=["tests.example_app.tables"], include_tags=["special"] + modules=["tests.example_apps.music.tables"], + include_tags=["special"], ) table_class_names = [i.__name__ for i in tables] @@ -124,7 +180,8 @@ def test_exclude_tags(self): Should return all Table subclasses without the specified tags. """ tables = table_finder( - modules=["tests.example_app.tables"], exclude_tags=["special"] + modules=["tests.example_apps.music.tables"], + exclude_tags=["special"], ) table_class_names = [i.__name__ for i in tables] @@ -135,6 +192,7 @@ def test_exclude_tags(self): [ "Band", "Concert", + "Instrument", "Manager", "RecordingStudio", "Shirt", @@ -142,3 +200,161 @@ def test_exclude_tags(self): "Venue", ], ) + + def test_exclude_imported(self): + """ + Make sure we can excluded imported Tables. + """ + filtered_tables = table_finder( + modules=["tests.conf.example"], + exclude_imported=True, + ) + + self.assertEqual( + [i.__name__ for i in filtered_tables], + ["Musician"], + ) + + # Now try without filtering: + all_tables = table_finder( + modules=["tests.conf.example"], + exclude_imported=False, + ) + + self.assertEqual( + sorted([i.__name__ for i in all_tables]), + ["BaseUser", "Musician"], + ) + + +class TestFinder(TestCase): + def test_get_table_classes(self): + """ + Make sure ``Table`` classes can be retrieved. + """ + finder = Finder() + + self.assertListEqual( + sorted(finder.get_table_classes(), key=lambda i: i.__name__), + [ + Band, + Concert, + Instrument, + Manager, + MegaTable, + Poster, + RecordingStudio, + Shirt, + SmallTable, + Ticket, + Venue, + ], + ) + + self.assertListEqual( + sorted( + finder.get_table_classes(include_apps=["music"]), + key=lambda i: i.__name__, + ), + [ + Band, + Concert, + Instrument, + Manager, + Poster, + RecordingStudio, + Shirt, + Ticket, + Venue, + ], + ) + + self.assertListEqual( + sorted( + finder.get_table_classes(exclude_apps=["music"]), + key=lambda i: i.__name__, + ), + [ + MegaTable, + SmallTable, + ], + ) + + with self.assertRaises(ValueError): + # You shouldn't be allowed to specify both include and exclude. + finder.get_table_classes( + exclude_apps=["music"], include_apps=["mega"] + ) + + def test_sort_app_configs(self): + """ + Make sure we can sort ``AppConfig`` based on their migration + dependencies. 
+ """ + app_config_1 = AppConfig( + app_name="app_1", + migrations_folder_path="", + ) + + app_config_1._migration_dependency_app_configs = [ + AppConfig( + app_name="app_2", + migrations_folder_path="", + ) + ] + + app_config_2 = AppConfig( + app_name="app_2", + migrations_folder_path="", + ) + + app_config_2._migration_dependency_app_configs = [] + + sorted_app_configs = Finder().sort_app_configs( + [app_config_2, app_config_1] + ) + + self.assertListEqual( + [i.app_name for i in sorted_app_configs], ["app_2", "app_1"] + ) + + +class TestPiccoloConfUpdater(TestCase): + + def test_modify_app_registry_src(self): + """ + Make sure the `piccolo_conf.py` source code can be modified + successfully. + """ + updater = PiccoloConfUpdater() + + new_src = updater._modify_app_registry_src( + src="APP_REGISTRY = AppRegistry(apps=[])", + app_module="music.piccolo_app", + ) + self.assertEqual( + new_src.strip(), + 'APP_REGISTRY = AppRegistry(apps=["music.piccolo_app"])', + ) + + def test_register_app(self): + """ + Make sure the new contents get written to disk. + """ + temp_dir = tempfile.gettempdir() + piccolo_conf_path = pathlib.Path(temp_dir) / "piccolo_conf.py" + + src = "APP_REGISTRY = AppRegistry(apps=[])" + + with open(piccolo_conf_path, "wt") as f: + f.write(src) + + updater = PiccoloConfUpdater(piccolo_conf_path=str(piccolo_conf_path)) + updater.register_app(app_module="music.piccolo_app") + + with open(piccolo_conf_path) as f: + contents = f.read().strip() + + self.assertEqual( + contents, 'APP_REGISTRY = AppRegistry(apps=["music.piccolo_app"])' + ) diff --git a/tests/conftest.py b/tests/conftest.py index b1c7257be..8411ebc38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,25 +6,42 @@ async def drop_tables(): - for table in [ - "venue", + tables = [ + "ticket", "concert", + "venue", "band", "manager", - "ticket", "poster", "migration", "musician", "my_table", "recording_studio", + "instrument", "shirt", - ]: - await ENGINE._run_in_new_connection(f"DROP TABLE IF EXISTS {table}") + "instrument", + "mega_table", + "small_table", + ] + assert ENGINE is not None + + if ENGINE.engine_type == "sqlite": + # SQLite doesn't allow us to drop more than one table at a time. + for table in tables: + await ENGINE._run_in_new_connection( + f"DROP TABLE IF EXISTS {table}" + ) + else: + table_str = ", ".join(tables) + await ENGINE._run_in_new_connection( + f"DROP TABLE IF EXISTS {table_str} CASCADE" + ) def pytest_sessionstart(session): """ - Make sure all the tables have been dropped. + Make sure all the tables have been dropped, just in case a previous test + run was aborted part of the way through. https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_configure """ diff --git a/tests/engine/test_extra_nodes.py b/tests/engine/test_extra_nodes.py new file mode 100644 index 000000000..42a2703f9 --- /dev/null +++ b/tests/engine/test_extra_nodes.py @@ -0,0 +1,40 @@ +from typing import cast +from unittest import TestCase +from unittest.mock import MagicMock + +from piccolo.columns.column_types import Varchar +from piccolo.engine import engine_finder +from piccolo.engine.postgres import PostgresEngine +from piccolo.table import Table +from tests.base import AsyncMock, engines_only + + +@engines_only("postgres", "cockroach") +class TestExtraNodes(TestCase): + def test_extra_nodes(self): + """ + Make sure that other nodes can be queried. 
+ """ + # Get the test database credentials: + test_engine = engine_finder() + assert test_engine is not None + + test_engine = cast(PostgresEngine, test_engine) + + EXTRA_NODE = MagicMock(spec=PostgresEngine(config=test_engine.config)) + EXTRA_NODE.run_querystring = AsyncMock(return_value=[]) + + DB = PostgresEngine( + config=test_engine.config, extra_nodes={"read_1": EXTRA_NODE} + ) + + class Manager(Table, db=DB): + name = Varchar() + + # Make sure the node is queried + Manager.select().run_sync(node="read_1") + self.assertTrue(EXTRA_NODE.run_querystring.called) + + # Make sure that a non existent node raises an error + with self.assertRaises(KeyError): + Manager.select().run_sync(node="read_2") diff --git a/tests/engine/test_logging.py b/tests/engine/test_logging.py new file mode 100644 index 000000000..2e6ec3d3a --- /dev/null +++ b/tests/engine/test_logging.py @@ -0,0 +1,39 @@ +from unittest.mock import patch + +from tests.base import DBTestCase +from tests.example_apps.music.tables import Manager + + +class TestLogging(DBTestCase): + def tearDown(self): + Manager._meta.db.log_queries = False + Manager._meta.db.log_responses = False + super().tearDown() + + def test_log_queries(self): + Manager._meta.db.log_queries = True + + with patch("piccolo.engine.base.Engine.print_query") as print_query: + Manager.select().run_sync() + print_query.assert_called_once() + + def test_log_responses(self): + Manager._meta.db.log_responses = True + + with patch( + "piccolo.engine.base.Engine.print_response" + ) as print_response: + Manager.select().run_sync() + print_response.assert_called_once() + + def test_log_queries_and_responses(self): + Manager._meta.db.log_queries = True + Manager._meta.db.log_responses = True + + with patch("piccolo.engine.base.Engine.print_query") as print_query: + with patch( + "piccolo.engine.base.Engine.print_response" + ) as print_response: + Manager.select().run_sync() + print_query.assert_called_once() + print_response.assert_called_once() diff --git a/tests/engine/test_nested_transaction.py b/tests/engine/test_nested_transaction.py index 3e0443658..71d519b79 100644 --- a/tests/engine/test_nested_transaction.py +++ b/tests/engine/test_nested_transaction.py @@ -5,9 +5,8 @@ from piccolo.engine.exceptions import TransactionError from piccolo.engine.sqlite import SQLiteEngine from piccolo.table import Table - -from ..base import DBTestCase, sqlite_only -from ..example_app.tables import Manager +from tests.base import DBTestCase, sqlite_only +from tests.example_apps.music.tables import Manager ENGINE_1 = SQLiteEngine(path="engine1.sqlite") ENGINE_2 = SQLiteEngine(path="engine2.sqlite") @@ -46,11 +45,13 @@ async def run_nested(self): self.assertTrue(await Musician.table_exists().run()) musician = await Musician.select("name").first().run() - self.assertTrue(musician["name"] == "Bob") + assert musician is not None + self.assertEqual(musician["name"], "Bob") self.assertTrue(await Roadie.table_exists().run()) roadie = await Roadie.select("name").first().run() - self.assertTrue(roadie["name"] == "Dave") + assert roadie is not None + self.assertEqual(roadie["name"], "Dave") def test_nested(self): asyncio.run(self.run_nested()) @@ -61,12 +62,25 @@ async def run_nested(self): """ Nested transactions currently aren't permitted in a connection. 
""" + # allow_nested=False with self.assertRaises(TransactionError): async with Manager._meta.db.transaction(): await Manager(name="Bob").save().run() - async with Manager._meta.db.transaction(): - await Manager(name="Dave").save().run() + async with Manager._meta.db.transaction(allow_nested=False): + pass + + # allow_nested=True + async with Manager._meta.db.transaction(): + async with Manager._meta.db.transaction(): + # Shouldn't raise an exception + pass + + # Utilise returned transaction. + async with Manager._meta.db.transaction(): + async with Manager._meta.db.transaction() as transaction: + await Manager(name="Dave").save().run() + await transaction.rollback() def test_nested(self): asyncio.run(self.run_nested()) diff --git a/tests/engine/test_pool.py b/tests/engine/test_pool.py index e39c66379..28f2db1c3 100644 --- a/tests/engine/test_pool.py +++ b/tests/engine/test_pool.py @@ -1,15 +1,20 @@ import asyncio +import os +import tempfile +from typing import cast +from unittest import TestCase +from unittest.mock import call, patch from piccolo.engine.postgres import PostgresEngine +from piccolo.engine.sqlite import SQLiteEngine +from tests.base import DBTestCase, engine_is, engines_only, sqlite_only +from tests.example_apps.music.tables import Manager -from ..base import DBTestCase, postgres_only -from ..example_app.tables import Manager - -@postgres_only +@engines_only("postgres", "cockroach") class TestPool(DBTestCase): - async def _create_pool(self): - engine: PostgresEngine = Manager._meta.db + async def _create_pool(self) -> None: + engine = cast(PostgresEngine, Manager._meta.db) await engine.start_connection_pool() assert engine.pool is not None @@ -22,7 +27,7 @@ async def _make_query(self): await Manager(name="Bob").save().run() response = await Manager.select().run() - self.assertTrue("Bob" in [i["name"] for i in response]) + self.assertIn("Bob", [i["name"] for i in response]) await Manager._meta.db.close_connection_pool() @@ -33,7 +38,12 @@ async def _make_many_queries(self): async def get_data(): response = await Manager.select().run() - self.assertEqual(response, [{"id": 1, "name": "Bob"}]) + if engine_is("cockroach"): + self.assertEqual( + response, [{"id": response[0]["id"], "name": "Bob"}] + ) + else: + self.assertEqual(response, [{"id": 1, "name": "Bob"}]) await asyncio.gather(*[get_data() for _ in range(500)]) @@ -59,10 +69,10 @@ def test_many_queries(self): asyncio.run(self._make_many_queries()) -@postgres_only +@engines_only("postgres", "cockroach") class TestPoolProxyMethods(DBTestCase): - async def _create_pool(self): - engine: PostgresEngine = Manager._meta.db + async def _create_pool(self) -> None: + engine = cast(PostgresEngine, Manager._meta.db) # Deliberate typo ('nnn'): await engine.start_connnection_pool() @@ -78,3 +88,36 @@ def test_proxy_methods(self): work, to ensure backwards compatibility. 
""" asyncio.run(self._create_pool()) + + +@sqlite_only +class TestConnectionPoolWarning(TestCase): + async def _create_pool(self): + sqlite_file = os.path.join(tempfile.gettempdir(), "engine.sqlite") + engine = SQLiteEngine(path=sqlite_file) + + with patch("piccolo.engine.base.colored_warning") as colored_warning: + await engine.start_connection_pool() + await engine.close_connection_pool() + + self.assertEqual( + colored_warning.call_args_list, + [ + call( + "Connection pooling is not supported for sqlite.", + stacklevel=3, + ), + call( + "Connection pooling is not supported for sqlite.", + stacklevel=3, + ), + ], + ) + + def test_warnings(self): + """ + Make sure that when trying to start and close a connection pool with + SQLite, a warning is printed out, as connection pools aren't currently + supported. + """ + asyncio.run(self._create_pool()) diff --git a/tests/engine/test_transaction.py b/tests/engine/test_transaction.py index 0f248a943..d381f5d14 100644 --- a/tests/engine/test_transaction.py +++ b/tests/engine/test_transaction.py @@ -1,8 +1,14 @@ import asyncio +from typing import cast from unittest import TestCase -from ..base import postgres_only -from ..example_app.tables import Band, Manager +import pytest + +from piccolo.engine.sqlite import SQLiteEngine, TransactionType +from piccolo.table import drop_db_tables_sync +from piccolo.utils.sync import run_sync +from tests.base import engines_only +from tests.example_apps.music.tables import Band, Manager class TestAtomic(TestCase): @@ -10,31 +16,62 @@ def test_error(self): """ Make sure queries in a transaction aren't committed if a query fails. """ - transaction = Band._meta.db.atomic() - transaction.add( + atomic = Band._meta.db.atomic() + atomic.add( Manager.create_table(), Band.create_table(), Band.raw("MALFORMED QUERY ... SHOULD ERROR"), ) try: - transaction.run_sync() + atomic.run_sync() except Exception: pass self.assertTrue(not Band.table_exists().run_sync()) self.assertTrue(not Manager.table_exists().run_sync()) def test_succeeds(self): - transaction = Band._meta.db.atomic() - transaction.add(Manager.create_table(), Band.create_table()) - transaction.run_sync() + """ + Make sure that when atomic is run successfully the database is modified + accordingly. + """ + atomic = Band._meta.db.atomic() + atomic.add(Manager.create_table(), Band.create_table()) + atomic.run_sync() self.assertTrue(Band.table_exists().run_sync()) self.assertTrue(Manager.table_exists().run_sync()) - transaction.add( - Band.alter().drop_table(), Manager.alter().drop_table() - ) - transaction.run_sync() + drop_db_tables_sync(Band, Manager) + + @engines_only("postgres", "cockroach") + def test_pool(self) -> None: + """ + Make sure atomic works correctly when a connection pool is active. + """ + + async def run() -> None: + """ + We have to run this async function, so we can use a connection + pool. 
+ """ + engine = Band._meta.db + await engine.start_connection_pool() + + atomic = engine.atomic() + atomic.add( + Manager.create_table(), + Band.create_table(), + ) + + await atomic.run() + await engine.close_connection_pool() + + run_sync(run()) + + self.assertTrue(Band.table_exists().run_sync()) + self.assertTrue(Manager.table_exists().run_sync()) + + drop_db_tables_sync(Band, Manager) class TestTransaction(TestCase): @@ -73,7 +110,35 @@ async def run_transaction(): self.assertTrue(Band.table_exists().run_sync()) self.assertTrue(Manager.table_exists().run_sync()) - @postgres_only + def test_manual_commit(self): + """ + The context manager automatically commits changes, but we also + allow the user to do it manually. + """ + + async def run_transaction(): + async with Band._meta.db.transaction() as transaction: + await Manager.create_table() + await transaction.commit() + + asyncio.run(run_transaction()) + self.assertTrue(Manager.table_exists().run_sync()) + + def test_manual_rollback(self): + """ + The context manager will automatically rollback changes if an exception + is raised, but we also allow the user to do it manually. + """ + + async def run_transaction(): + async with Band._meta.db.transaction() as transaction: + await Manager.create_table() + await transaction.rollback() + + asyncio.run(run_transaction()) + self.assertFalse(Manager.table_exists().run_sync()) + + @engines_only("postgres") def test_transaction_id(self): """ An extra sanity check, that the transaction id is the same for each @@ -92,8 +157,154 @@ async def run_transaction(): return [i[0]["txid_current"] for i in responses] txids = asyncio.run(run_transaction()) - assert len(set(txids)) == 1 + self.assertEqual(len(set(txids)), 1) # Now run it again and make sure the transaction ids differ. next_txids = asyncio.run(run_transaction()) - assert txids != next_txids + self.assertNotEqual(txids, next_txids) + + +class TestTransactionExists(TestCase): + def test_exists(self): + """ + Make sure we can detect when code is within a transaction. + """ + engine = cast(SQLiteEngine, Manager._meta.db) + + async def run_inside_transaction(): + async with engine.transaction(): + return engine.transaction_exists() + + self.assertTrue(asyncio.run(run_inside_transaction())) + + async def run_outside_transaction(): + return engine.transaction_exists() + + self.assertFalse(asyncio.run(run_outside_transaction())) + + +@engines_only("sqlite") +class TestTransactionType(TestCase): + def setUp(self): + Manager.create_table().run_sync() + + def tearDown(self): + Manager.alter().drop_table().run_sync() + + def test_transaction(self): + """ + With SQLite, we can specify the transaction type. This helps when + we want to do concurrent writes, to avoid locking the database. + + https://github.com/piccolo-orm/piccolo/issues/687 + """ + engine = cast(SQLiteEngine, Manager._meta.db) + + async def run_transaction(name: str): + async with engine.transaction( + transaction_type=TransactionType.immediate + ): + # This does a SELECT followed by an INSERT, so is a good test. + # If using TransactionType.deferred it would fail because + # the database will become locked. + await Manager.objects().get_or_create(Manager.name == name) + + manager_names = [f"Manager_{i}" for i in range(1, 10)] + + async def run_all(): + """ + Run all of the transactions concurrently. + """ + await asyncio.gather( + *[run_transaction(name=name) for name in manager_names] + ) + + asyncio.run(run_all()) + + # Make sure it all ran effectively. 
+ self.assertListEqual( + Manager.select(Manager.name) + .order_by(Manager.name) + .output(as_list=True) + .run_sync(), + manager_names, + ) + + def test_atomic(self): + """ + Similar to above, but with ``Atomic``. + """ + engine = cast(SQLiteEngine, Manager._meta.db) + + async def run_atomic(name: str): + atomic = engine.atomic(transaction_type=TransactionType.immediate) + atomic.add(Manager.objects().get_or_create(Manager.name == name)) + await atomic.run() + + manager_names = [f"Manager_{i}" for i in range(1, 10)] + + async def run_all(): + """ + Run all of the transactions concurrently. + """ + await asyncio.gather( + *[run_atomic(name=name) for name in manager_names] + ) + + asyncio.run(run_all()) + + # Make sure it all ran effectively. + self.assertListEqual( + Manager.select(Manager.name) + .order_by(Manager.name) + .output(as_list=True) + .run_sync(), + manager_names, + ) + + +class TestSavepoint(TestCase): + def setUp(self): + Manager.create_table().run_sync() + + def tearDown(self): + Manager.alter().drop_table().run_sync() + + def test_savepoint(self): + async def run_test(): + async with Manager._meta.db.transaction() as transaction: + await Manager.insert(Manager(name="Manager 1")) + savepoint = await transaction.savepoint() + await Manager.insert(Manager(name="Manager 2")) + await savepoint.rollback_to() + + run_sync(run_test()) + + self.assertListEqual( + Manager.select(Manager.name).run_sync(), [{"name": "Manager 1"}] + ) + + def test_named_savepoint(self): + async def run_test(): + async with Manager._meta.db.transaction() as transaction: + await Manager.insert(Manager(name="Manager 1")) + await transaction.savepoint("my_savepoint") + await Manager.insert(Manager(name="Manager 2")) + await transaction.rollback_to("my_savepoint") + + run_sync(run_test()) + + self.assertListEqual( + Manager.select(Manager.name).run_sync(), [{"name": "Manager 1"}] + ) + + def test_savepoint_sqli_checks(self): + # Added to test the fix for GHSA-xq59-7jf3-rjc6 + async def run_test(): + async with Manager._meta.db.transaction() as transaction: + await transaction.savepoint( + "my_savepoint; SELECT * FROM Manager" + ) + + with pytest.raises(ValueError): + run_sync(run_test()) diff --git a/tests/engine/test_version_parsing.py b/tests/engine/test_version_parsing.py index d49ce65dd..08cd7a7c2 100644 --- a/tests/engine/test_version_parsing.py +++ b/tests/engine/test_version_parsing.py @@ -2,10 +2,10 @@ from piccolo.engine.postgres import PostgresEngine -from ..base import postgres_only +from ..base import engines_only -@postgres_only +@engines_only("postgres", "cockroach") class TestVersionParsing(TestCase): def test_version_parsing(self): """ diff --git a/tests/example_app/piccolo_app.py b/tests/example_app/piccolo_app.py deleted file mode 100644 index d74c916e2..000000000 --- a/tests/example_app/piccolo_app.py +++ /dev/null @@ -1,35 +0,0 @@ -import os - -from piccolo.conf.apps import AppConfig - -from .tables import ( - Band, - Concert, - Manager, - Poster, - RecordingStudio, - Shirt, - Ticket, - Venue, -) - -CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) - - -APP_CONFIG = AppConfig( - app_name="example_app", - table_classes=[ - Manager, - Band, - Venue, - Concert, - Ticket, - Poster, - Shirt, - RecordingStudio, - ], - migrations_folder_path=os.path.join( - CURRENT_DIRECTORY, "piccolo_migrations" - ), - commands=[], -) diff --git a/tests/example_apps/__init__.py b/tests/example_apps/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/example_apps/mega/__init__.py b/tests/example_apps/mega/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/example_apps/mega/piccolo_app.py b/tests/example_apps/mega/piccolo_app.py new file mode 100644 index 000000000..f565bb5aa --- /dev/null +++ b/tests/example_apps/mega/piccolo_app.py @@ -0,0 +1,17 @@ +import os + +from piccolo.conf.apps import AppConfig + +from .tables import MegaTable, SmallTable + +CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + + +APP_CONFIG = AppConfig( + app_name="mega", + table_classes=[MegaTable, SmallTable], + migrations_folder_path=os.path.join( + CURRENT_DIRECTORY, "piccolo_migrations" + ), + commands=[], +) diff --git a/tests/example_apps/mega/piccolo_migrations/2021-09-20T21-23-25-698988.py b/tests/example_apps/mega/piccolo_migrations/2021-09-20T21-23-25-698988.py new file mode 100644 index 000000000..f204dbb7c --- /dev/null +++ b/tests/example_apps/mega/piccolo_migrations/2021-09-20T21-23-25-698988.py @@ -0,0 +1,450 @@ +from decimal import Decimal + +from piccolo.apps.migrations.auto import MigrationManager +from piccolo.columns.base import OnDelete, OnUpdate +from piccolo.columns.column_types import ( + JSON, + JSONB, + UUID, + BigInt, + Boolean, + Bytea, + Date, + DoublePrecision, + ForeignKey, + Integer, + Interval, + Numeric, + Real, + Serial, + SmallInt, + Text, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.columns.defaults.date import DateNow +from piccolo.columns.defaults.interval import IntervalCustom +from piccolo.columns.defaults.timestamp import TimestampNow +from piccolo.columns.defaults.timestamptz import TimestamptzNow +from piccolo.columns.defaults.uuid import UUID4 +from piccolo.columns.indexes import IndexMethod +from piccolo.table import Table + + +class SmallTable(Table, tablename="small_table"): + id = Serial( + null=False, + primary_key=True, + unique=False, + index=False, + index_method=IndexMethod.btree, + choices=None, + ) + + +ID = "2021-09-20T21:23:25:698988" +VERSION = "0.49.0" +DESCRIPTION = "" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="mega", description=DESCRIPTION + ) + + manager.add_table("MegaTable", tablename="mega_table") + + manager.add_table("SmallTable", tablename="small_table") + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="bigint_col", + column_class_name="BigInt", + column_class=BigInt, + params={ + "default": 0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="boolean_col", + column_class_name="Boolean", + column_class=Boolean, + params={ + "default": False, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="bytea_col", + column_class_name="Bytea", + column_class=Bytea, + params={ + "default": b"", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="date_col", + column_class_name="Date", + column_class=Date, + params={ + "default": DateNow(), + "null": False, + "primary_key": False, + "unique": False, + 
"index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="foreignkey_col", + column_class_name="ForeignKey", + column_class=ForeignKey, + params={ + "references": SmallTable, + "on_delete": OnDelete.cascade, + "on_update": OnUpdate.cascade, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="integer_col", + column_class_name="Integer", + column_class=Integer, + params={ + "default": 0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="interval_col", + column_class_name="Interval", + column_class=Interval, + params={ + "default": IntervalCustom( + weeks=0, + days=0, + hours=0, + minutes=0, + seconds=0, + milliseconds=0, + microseconds=0, + ), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="json_col", + column_class_name="JSON", + column_class=JSON, + params={ + "default": "{}", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="jsonb_col", + column_class_name="JSONB", + column_class=JSONB, + params={ + "default": "{}", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="numeric_col", + column_class_name="Numeric", + column_class=Numeric, + params={ + "default": Decimal("0"), + "digits": (5, 2), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="real_col", + column_class_name="Real", + column_class=Real, + params={ + "default": 0.0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="double_precision_col", + column_class_name="DoublePrecision", + column_class=DoublePrecision, + params={ + "default": 0.0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="smallint_col", + column_class_name="SmallInt", + column_class=SmallInt, + params={ + "default": 0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="text_col", + column_class_name="Text", + column_class=Text, + params={ + "default": "", + "null": False, + 
"primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="timestamp_col", + column_class_name="Timestamp", + column_class=Timestamp, + params={ + "default": TimestampNow(), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="timestamptz_col", + column_class_name="Timestamptz", + column_class=Timestamptz, + params={ + "default": TimestamptzNow(), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="uuid_col", + column_class_name="UUID", + column_class=UUID, + params={ + "default": UUID4(), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="varchar_col", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="unique_col", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": False, + "primary_key": False, + "unique": True, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="null_col", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="MegaTable", + tablename="mega_table", + column_name="not_null_col", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + manager.add_column( + table_class_name="SmallTable", + tablename="small_table", + column_name="varchar_col", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + return manager diff --git a/tests/example_apps/mega/piccolo_migrations/__init__.py b/tests/example_apps/mega/piccolo_migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/example_apps/mega/tables.py b/tests/example_apps/mega/tables.py new file mode 100644 index 000000000..8947bb3e2 --- /dev/null +++ b/tests/example_apps/mega/tables.py @@ -0,0 +1,94 @@ +""" +This is a useful table when we want to test all possible column types. 
+""" + +from piccolo.columns.column_types import ( + JSON, + JSONB, + UUID, + BigInt, + Boolean, + Bytea, + Date, + DoublePrecision, + ForeignKey, + Integer, + Interval, + Numeric, + Real, + SmallInt, + Text, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.engine.finder import engine_finder +from piccolo.table import Table + +engine = engine_finder() + + +class SmallTable(Table): + varchar_col = Varchar() + + +if engine.engine_type != "cockroach": # type: ignore + + class MegaTable(Table): # type: ignore + """ + A table containing all of the column types, different column kwargs. + """ + + bigint_col = BigInt() + boolean_col = Boolean() + bytea_col = Bytea() + date_col = Date() + foreignkey_col = ForeignKey(SmallTable) + integer_col = Integer() + interval_col = Interval() + json_col = JSON() + jsonb_col = JSONB() + numeric_col = Numeric(digits=(5, 2)) + real_col = Real() + double_precision_col = DoublePrecision() + smallint_col = SmallInt() + text_col = Text() + timestamp_col = Timestamp() + timestamptz_col = Timestamptz() + uuid_col = UUID() + varchar_col = Varchar() + + unique_col = Varchar(unique=True) + null_col = Varchar(null=True) + not_null_col = Varchar(null=False) + +else: + + class MegaTable(Table): # type: ignore + """ + Special version for Cockroach. + A table containing all of the column types, different column kwargs. + """ + + bigint_col = BigInt() + boolean_col = Boolean() + bytea_col = Bytea() + date_col = Date() + foreignkey_col = ForeignKey(SmallTable) + integer_col = BigInt() + interval_col = Interval() + json_col = JSONB() + jsonb_col = JSONB() + numeric_col = Numeric(digits=(5, 2)) + real_col = Real() + double_precision_col = DoublePrecision() + smallint_col = SmallInt() + text_col = Text() + timestamp_col = Timestamp() + timestamptz_col = Timestamptz() + uuid_col = UUID() + varchar_col = Varchar() + + unique_col = Varchar(unique=True) + null_col = Varchar(null=True) + not_null_col = Varchar(null=False) diff --git a/tests/example_apps/music/__init__.py b/tests/example_apps/music/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/example_apps/music/piccolo_app.py b/tests/example_apps/music/piccolo_app.py new file mode 100644 index 000000000..cb473faf7 --- /dev/null +++ b/tests/example_apps/music/piccolo_app.py @@ -0,0 +1,15 @@ +import os + +from piccolo.conf.apps import AppConfig, table_finder + +CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + + +APP_CONFIG = AppConfig( + app_name="music", + table_classes=table_finder(["tests.example_apps.music.tables"]), + migrations_folder_path=os.path.join( + CURRENT_DIRECTORY, "piccolo_migrations" + ), + commands=[], +) diff --git a/tests/example_app/piccolo_migrations/2020-12-17T18-44-30.py b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-30.py similarity index 85% rename from tests/example_app/piccolo_migrations/2020-12-17T18-44-30.py rename to tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-30.py index 94943796a..be87769ea 100644 --- a/tests/example_app/piccolo_migrations/2020-12-17T18-44-30.py +++ b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-30.py @@ -12,10 +12,10 @@ class Manager(Table, tablename="manager"): async def forwards(): - manager = MigrationManager(migration_id=ID, app_name="example_app") + manager = MigrationManager(migration_id=ID, app_name="music") - manager.add_table("Manager", tablename="manager") manager.add_table("Band", tablename="band") + manager.add_table("Manager", tablename="manager") manager.add_column( 
table_class_name="Band", @@ -26,8 +26,7 @@ async def forwards(): "length": 50, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -44,8 +43,7 @@ async def forwards(): "on_update": OnUpdate.cascade, "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -59,8 +57,7 @@ async def forwards(): params={ "default": 0, "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -75,8 +72,7 @@ async def forwards(): "length": 50, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, diff --git a/tests/example_app/piccolo_migrations/2020-12-17T18-44-39.py b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-39.py similarity index 87% rename from tests/example_app/piccolo_migrations/2020-12-17T18-44-39.py rename to tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-39.py index b8439da2e..48048ce5a 100644 --- a/tests/example_app/piccolo_migrations/2020-12-17T18-44-39.py +++ b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-39.py @@ -18,7 +18,7 @@ class Venue(Table, tablename="venue"): async def forwards(): - manager = MigrationManager(migration_id=ID, app_name="example_app") + manager = MigrationManager(migration_id=ID, app_name="music") manager.add_table("Ticket", tablename="ticket") manager.add_table("Venue", tablename="venue") @@ -33,8 +33,7 @@ async def forwards(): "default": Decimal("0"), "digits": (5, 2), "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -51,8 +50,7 @@ async def forwards(): "on_update": OnUpdate.cascade, "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -69,8 +67,7 @@ async def forwards(): "on_update": OnUpdate.cascade, "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -87,8 +84,7 @@ async def forwards(): "on_update": OnUpdate.cascade, "default": None, "null": True, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -103,8 +99,7 @@ async def forwards(): "length": 100, "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, @@ -118,8 +113,7 @@ async def forwards(): params={ "default": 0, "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, }, diff --git a/tests/example_app/piccolo_migrations/2020-12-17T18-44-44.py b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-44.py similarity index 79% rename from tests/example_app/piccolo_migrations/2020-12-17T18-44-44.py rename to tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-44.py index fadcc5b48..52a94fff8 100644 --- a/tests/example_app/piccolo_migrations/2020-12-17T18-44-44.py +++ b/tests/example_apps/music/piccolo_migrations/2020-12-17T18-44-44.py @@ -5,7 +5,7 @@ async def forwards(): - manager = MigrationManager(migration_id=ID, app_name="example_app") + manager = MigrationManager(migration_id=ID, app_name="music") manager.add_table("Poster", tablename="poster") @@ -17,8 +17,7 @@ async def forwards(): params={ "default": "", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": 
False, "index": False, }, diff --git a/tests/example_app/piccolo_migrations/2021-07-25T22-38-48-009306.py b/tests/example_apps/music/piccolo_migrations/2021-07-25T22-38-48-009306.py similarity index 88% rename from tests/example_app/piccolo_migrations/2021-07-25T22-38-48-009306.py rename to tests/example_apps/music/piccolo_migrations/2021-07-25T22-38-48-009306.py index 1277d396b..0bddaf7cf 100644 --- a/tests/example_app/piccolo_migrations/2021-07-25T22-38-48-009306.py +++ b/tests/example_apps/music/piccolo_migrations/2021-07-25T22-38-48-009306.py @@ -9,7 +9,7 @@ async def forwards(): - manager = MigrationManager(migration_id=ID, app_name="example_app") + manager = MigrationManager(migration_id=ID, app_name="music") manager.add_table("Shirt", tablename="shirt") @@ -25,8 +25,7 @@ async def forwards(): "length": 1, "default": "l", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, "index_method": IndexMethod.btree, @@ -45,8 +44,7 @@ async def forwards(): params={ "default": "{}", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, "index_method": IndexMethod.btree, @@ -63,8 +61,7 @@ async def forwards(): params={ "default": "{}", "null": False, - "primary": False, - "key": False, + "primary_key": False, "unique": False, "index": False, "index_method": IndexMethod.btree, diff --git a/tests/example_apps/music/piccolo_migrations/2021-09-06T13-58-23-024723.py b/tests/example_apps/music/piccolo_migrations/2021-09-06T13-58-23-024723.py new file mode 100644 index 000000000..5f9661397 --- /dev/null +++ b/tests/example_apps/music/piccolo_migrations/2021-09-06T13-58-23-024723.py @@ -0,0 +1,48 @@ +from piccolo.apps.migrations.auto import MigrationManager +from piccolo.columns.base import OnDelete, OnUpdate +from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.columns.indexes import IndexMethod +from piccolo.table import Table + + +class Concert(Table, tablename="concert"): + id = Serial( + null=False, + primary_key=True, + unique=False, + index=False, + index_method=IndexMethod.btree, + choices=None, + ) + + +ID = "2021-09-06T13:58:23:024723" +VERSION = "0.43.0" +DESCRIPTION = "" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="music", description=DESCRIPTION + ) + + manager.add_column( + table_class_name="Ticket", + tablename="ticket", + column_name="concert", + column_class_name="ForeignKey", + column_class=ForeignKey, + params={ + "references": Concert, + "on_delete": OnDelete.cascade, + "on_update": OnUpdate.cascade, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + }, + ) + + return manager diff --git a/tests/example_apps/music/piccolo_migrations/2021-11-13T14-01-46-114725.py b/tests/example_apps/music/piccolo_migrations/2021-11-13T14-01-46-114725.py new file mode 100644 index 000000000..e7db0f3a9 --- /dev/null +++ b/tests/example_apps/music/piccolo_migrations/2021-11-13T14-01-46-114725.py @@ -0,0 +1,24 @@ +from piccolo.apps.migrations.auto import MigrationManager +from piccolo.columns.column_types import Integer + +ID = "2021-11-13T14:01:46:114725" +VERSION = "0.59.0" +DESCRIPTION = "" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="music", description=DESCRIPTION + ) + + manager.alter_column( + table_class_name="Venue", + tablename="venue", + column_name="capacity", + params={"secret": True}, + 
old_params={"secret": False}, + column_class=Integer, + old_column_class=Integer, + ) + + return manager diff --git a/tests/example_apps/music/piccolo_migrations/music_2024_05_28t23_15_41_018844.py b/tests/example_apps/music/piccolo_migrations/music_2024_05_28t23_15_41_018844.py new file mode 100644 index 000000000..a1428eead --- /dev/null +++ b/tests/example_apps/music/piccolo_migrations/music_2024_05_28t23_15_41_018844.py @@ -0,0 +1,84 @@ +from piccolo.apps.migrations.auto.migration_manager import MigrationManager +from piccolo.columns.base import OnDelete, OnUpdate +from piccolo.columns.column_types import ForeignKey, Serial, Varchar +from piccolo.columns.indexes import IndexMethod +from piccolo.table import Table + + +class RecordingStudio(Table, tablename="recording_studio", schema=None): + id = Serial( + null=False, + primary_key=True, + unique=False, + index=False, + index_method=IndexMethod.btree, + choices=None, + db_column_name="id", + secret=False, + ) + + +ID = "2024-05-28T23:15:41:018844" +VERSION = "1.5.1" +DESCRIPTION = "" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="music", description=DESCRIPTION + ) + + manager.add_table( + class_name="Instrument", + tablename="instrument", + schema=None, + columns=None, + ) + + manager.add_column( + table_class_name="Instrument", + tablename="instrument", + column_name="name", + db_column_name="name", + column_class_name="Varchar", + column_class=Varchar, + params={ + "length": 255, + "default": "", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="Instrument", + tablename="instrument", + column_name="recording_studio", + db_column_name="recording_studio", + column_class_name="ForeignKey", + column_class=ForeignKey, + params={ + "references": RecordingStudio, + "on_delete": OnDelete.cascade, + "on_update": OnUpdate.cascade, + "target_column": None, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + return manager diff --git a/tests/example_apps/music/piccolo_migrations/music_2024_06_19t18_11_05_793132.py b/tests/example_apps/music/piccolo_migrations/music_2024_06_19t18_11_05_793132.py new file mode 100644 index 000000000..5d884210d --- /dev/null +++ b/tests/example_apps/music/piccolo_migrations/music_2024_06_19t18_11_05_793132.py @@ -0,0 +1,20 @@ +from piccolo.apps.migrations.auto.migration_manager import MigrationManager + +ID = "2024-06-19T18:11:05:793132" +VERSION = "1.11.0" +DESCRIPTION = "An example fake migration" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="", description=DESCRIPTION, fake=True + ) + + def run(): + # This should never run, as this migrations is `fake=True`. It's here + # for testing purposes (to make sure it never gets triggered). 
+ print("Running fake migration") + + manager.add_raw(run) + + return manager diff --git a/tests/example_apps/music/tables.py b/tests/example_apps/music/tables.py new file mode 100644 index 000000000..9e0cdbb39 --- /dev/null +++ b/tests/example_apps/music/tables.py @@ -0,0 +1,127 @@ +from enum import Enum + +from piccolo.columns import ( + JSON, + JSONB, + BigInt, + ForeignKey, + Integer, + Numeric, + Serial, + Text, + Varchar, +) +from piccolo.columns.readable import Readable +from piccolo.engine.finder import engine_finder +from piccolo.table import Table + +engine = engine_finder() + +############################################################################### +# Simple example + + +class Manager(Table): + id: Serial + name = Varchar(length=50) + + @classmethod + def get_readable(cls) -> Readable: + return Readable(template="%s", columns=[cls.name]) + + +class Band(Table): + id: Serial + name = Varchar(length=50) + manager = ForeignKey(Manager, null=True) + popularity = ( + BigInt(default=0) + if engine and engine.engine_type == "cockroach" + else Integer(default=0) + ) + + @classmethod + def get_readable(cls) -> Readable: + return Readable(template="%s", columns=[cls.name]) + + +############################################################################### +# More complex + + +class Venue(Table): + id: Serial + name = Varchar(length=100) + capacity = Integer(default=0, secret=True) + + @classmethod + def get_readable(cls) -> Readable: + return Readable(template="%s", columns=[cls.name]) + + +class Concert(Table): + id: Serial + band_1 = ForeignKey(Band) + band_2 = ForeignKey(Band) + venue = ForeignKey(Venue) + + @classmethod + def get_readable(cls) -> Readable: + return Readable( + template="%s and %s at %s, capacity %s", + columns=[ + cls.band_1.name, + cls.band_2.name, + cls.venue.name, + cls.venue.capacity, + ], + ) + + +class Ticket(Table): + id: Serial + concert = ForeignKey(Concert) + price = Numeric(digits=(5, 2)) + + +class Poster(Table, tags=["special"]): + """ + Has tags for tests which need it. + """ + + id: Serial + content = Text() + + +class Shirt(Table): + """ + Used for testing columns with a choices attribute. + """ + + class Size(str, Enum): + small = "s" + medium = "m" + large = "l" + + id: Serial + size = Varchar(length=1, choices=Size, default=Size.large) + + +class RecordingStudio(Table): + """ + Used for testing JSON and JSONB columns. + """ + + id: Serial + facilities = JSON() + facilities_b = JSONB() + + +class Instrument(Table): + """ + Used for testing foreign keys to a table with a JSON column. + """ + + id: Serial + name = Varchar() + recording_studio = ForeignKey(RecordingStudio) diff --git a/tests/example_app/tables.py b/tests/example_apps/music/tables_detailed.py similarity index 65% rename from tests/example_app/tables.py rename to tests/example_apps/music/tables_detailed.py index dd16984ba..f28692396 100644 --- a/tests/example_app/tables.py +++ b/tests/example_apps/music/tables_detailed.py @@ -1,12 +1,26 @@ +# TODO - these are much better example tables than in tables.py, but many +# tests will break if we change them. In the future migrate this file to +# tables.py and fix the tests. 
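+# (Until then, tables.py remains the canonical set of test fixtures.)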
+ +from datetime import timedelta from enum import Enum from piccolo.columns import ( JSON, JSONB, + UUID, + BigInt, + Boolean, + Bytea, + Date, ForeignKey, Integer, + Interval, Numeric, + SmallInt, Text, + Timestamp, + Timestamptz, Varchar, ) from piccolo.columns.readable import Readable @@ -18,6 +32,7 @@ class Manager(Table): name = Varchar(length=50) + touring = Boolean(default=False) @classmethod def get_readable(cls) -> Readable: @@ -25,6 +40,8 @@ def get_readable(cls) -> Readable: class Band(Table): + label_id = UUID() + date_signed = Date() name = Varchar(length=50) manager = ForeignKey(Manager, null=True) popularity = Integer(default=0) @@ -44,9 +61,15 @@ class Concert(Table): band_2 = ForeignKey(Band) venue = ForeignKey(Venue) + duration = Interval(default=timedelta(weeks=5, days=3)) + net_profit = SmallInt(default=-32768) + class Ticket(Table): + concert = ForeignKey(Concert) price = Numeric(digits=(5, 2)) + purchase_time = Timestamp() + purchase_time_tz = Timestamptz() class Poster(Table, tags=["special"]): @@ -54,6 +77,7 @@ class Poster(Table, tags=["special"]): Has tags for tests which need it. """ + image = Bytea(default=b"\xbd\x78\xd8") content = Text() @@ -75,5 +99,6 @@ class RecordingStudio(Table): Used for testing JSON and JSONB columns. """ - facilities = JSON() + facilities = JSON(default={"amplifier": False, "microphone": True}) facilities_b = JSONB() + records = BigInt(default=9223372036854775807) diff --git a/tests/postgres_conf.py b/tests/postgres_conf.py index 5a795c4fd..af21dcbc5 100644 --- a/tests/postgres_conf.py +++ b/tests/postgres_conf.py @@ -14,4 +14,9 @@ ) -APP_REGISTRY = AppRegistry(apps=["tests.example_app.piccolo_app"]) +APP_REGISTRY = AppRegistry( + apps=[ + "tests.example_apps.music.piccolo_app", + "tests.example_apps.mega.piccolo_app", + ] +) diff --git a/tests/query/functions/__init__.py b/tests/query/functions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/query/functions/base.py b/tests/query/functions/base.py new file mode 100644 index 000000000..1549709a6 --- /dev/null +++ b/tests/query/functions/base.py @@ -0,0 +1,21 @@ +from piccolo.testing.test_case import TableTest +from tests.example_apps.music.tables import Band, Manager + + +class BandTest(TableTest): + tables = [Band, Manager] + + def setUp(self) -> None: + super().setUp() + + manager = Manager({Manager.name: "Guido"}) + manager.save().run_sync() + + band = Band( + { + Band.name: "Pythonistas", + Band.manager: manager, + Band.popularity: 1000, + } + ) + band.save().run_sync() diff --git a/tests/query/functions/test_datetime.py b/tests/query/functions/test_datetime.py new file mode 100644 index 000000000..382f688ab --- /dev/null +++ b/tests/query/functions/test_datetime.py @@ -0,0 +1,113 @@ +import datetime + +from piccolo.columns import Timestamp +from piccolo.query.functions.datetime import ( + Day, + Extract, + Hour, + Minute, + Month, + Second, + Strftime, + Year, +) +from piccolo.table import Table +from piccolo.testing.test_case import TableTest +from tests.base import engines_only, sqlite_only + + +class Concert(Table): + starts = Timestamp() + + +class DatetimeTest(TableTest): + tables = [Concert] + + def setUp(self) -> None: + super().setUp() + self.concert = Concert( + { + Concert.starts: datetime.datetime( + year=2024, month=6, day=14, hour=23, minute=46, second=10 + ) + } + ) + self.concert.save().run_sync() + + +@engines_only("postgres", "cockroach") +class TestExtract(DatetimeTest): + def test_extract(self): + self.assertEqual( + 
Concert.select( + Extract(Concert.starts, "year", alias="starts_year") + ).run_sync(), + [{"starts_year": self.concert.starts.year}], + ) + + def test_invalid_format(self): + with self.assertRaises(ValueError): + Extract( + Concert.starts, + "abc123", # type: ignore + alias="starts_year", + ) + + +@sqlite_only +class TestStrftime(DatetimeTest): + def test_strftime(self): + self.assertEqual( + Concert.select( + Strftime(Concert.starts, "%Y", alias="starts_year") + ).run_sync(), + [{"starts_year": str(self.concert.starts.year)}], + ) + + +class TestDatabaseAgnostic(DatetimeTest): + def test_year(self): + self.assertEqual( + Concert.select( + Year(Concert.starts, alias="starts_year") + ).run_sync(), + [{"starts_year": self.concert.starts.year}], + ) + + def test_month(self): + self.assertEqual( + Concert.select( + Month(Concert.starts, alias="starts_month") + ).run_sync(), + [{"starts_month": self.concert.starts.month}], + ) + + def test_day(self): + self.assertEqual( + Concert.select(Day(Concert.starts, alias="starts_day")).run_sync(), + [{"starts_day": self.concert.starts.day}], + ) + + def test_hour(self): + self.assertEqual( + Concert.select( + Hour(Concert.starts, alias="starts_hour") + ).run_sync(), + [{"starts_hour": self.concert.starts.hour}], + ) + + def test_minute(self): + self.assertEqual( + Concert.select( + Minute(Concert.starts, alias="starts_minute") + ).run_sync(), + [{"starts_minute": self.concert.starts.minute}], + ) + + def test_second(self): + self.assertEqual( + Concert.select( + Second(Concert.starts, alias="starts_second") + ).run_sync(), + [{"starts_second": self.concert.starts.second}], + ) diff --git a/tests/query/functions/test_functions.py b/tests/query/functions/test_functions.py new file mode 100644 index 000000000..cb306dcc4 --- /dev/null +++ b/tests/query/functions/test_functions.py @@ -0,0 +1,64 @@ +from piccolo.query.functions import Reverse, Upper +from piccolo.querystring import QueryString +from tests.base import engines_skip +from tests.example_apps.music.tables import Band + +from .base import BandTest + + +@engines_skip("sqlite") +class TestNested(BandTest): + """ + Skip the test for SQLite, as it doesn't support ``Reverse``. + """ + + def test_nested(self): + """ + Make sure we can nest functions. + """ + response = Band.select(Upper(Reverse(Band.name))).run_sync() + self.assertListEqual(response, [{"upper": "SATSINOHTYP"}]) + + def test_nested_with_joined_column(self): + """ + Make sure nested functions can be used on a column from a joined table. + """ + response = Band.select(Upper(Reverse(Band.manager._.name))).run_sync() + self.assertListEqual(response, [{"upper": "ODIUG"}]) + + def test_nested_within_querystring(self): + """ + If we wrap a function in a custom ``QueryString``, make sure the columns + are still accessible, so joins are successful. + """ + response = Band.select( + QueryString("CONCAT({}, '!')", Upper(Band.manager._.name)), + ).run_sync() + + self.assertListEqual(response, [{"concat": "GUIDO!"}]) + + +class TestWhereClause(BandTest): + + def test_where(self): + """ + Make sure where clauses work with functions. + """ + response = ( + Band.select(Band.name) + .where(Upper(Band.name) == "PYTHONISTAS") + .run_sync() + ) + self.assertListEqual(response, [{"name": "Pythonistas"}]) + + def test_where_with_joined_column(self): + """ + Make sure where clauses work with functions, when a joined column is + used.
+ """ + response = ( + Band.select(Band.name) + .where(Upper(Band.manager._.name) == "GUIDO") + .run_sync() + ) + self.assertListEqual(response, [{"name": "Pythonistas"}]) diff --git a/tests/query/functions/test_math.py b/tests/query/functions/test_math.py new file mode 100644 index 000000000..330645c36 --- /dev/null +++ b/tests/query/functions/test_math.py @@ -0,0 +1,38 @@ +import decimal + +from piccolo.columns import Numeric +from piccolo.query.functions.math import Abs, Ceil, Floor, Round +from piccolo.table import Table +from piccolo.testing.test_case import TableTest + + +class Ticket(Table): + price = Numeric(digits=(5, 2)) + + +class TestMath(TableTest): + + tables = [Ticket] + + def setUp(self): + super().setUp() + self.ticket = Ticket({Ticket.price: decimal.Decimal("36.50")}) + self.ticket.save().run_sync() + + def test_floor(self): + response = Ticket.select(Floor(Ticket.price, alias="price")).run_sync() + self.assertListEqual(response, [{"price": decimal.Decimal("36.00")}]) + + def test_ceil(self): + response = Ticket.select(Ceil(Ticket.price, alias="price")).run_sync() + self.assertListEqual(response, [{"price": decimal.Decimal("37.00")}]) + + def test_abs(self): + self.ticket.price = decimal.Decimal("-1.50") + self.ticket.save().run_sync() + response = Ticket.select(Abs(Ticket.price, alias="price")).run_sync() + self.assertListEqual(response, [{"price": decimal.Decimal("1.50")}]) + + def test_round(self): + response = Ticket.select(Round(Ticket.price, alias="price")).run_sync() + self.assertListEqual(response, [{"price": decimal.Decimal("37.00")}]) diff --git a/tests/query/functions/test_string.py b/tests/query/functions/test_string.py new file mode 100644 index 000000000..bd3a8c2ab --- /dev/null +++ b/tests/query/functions/test_string.py @@ -0,0 +1,57 @@ +import pytest + +from piccolo.query.functions.string import Concat, Upper +from tests.base import engine_version_lt, is_running_sqlite +from tests.example_apps.music.tables import Band + +from .base import BandTest + + +class TestUpper(BandTest): + + def test_column(self): + """ + Make sure we can uppercase a column's value. + """ + response = Band.select(Upper(Band.name)).run_sync() + self.assertListEqual(response, [{"upper": "PYTHONISTAS"}]) + + def test_alias(self): + response = Band.select(Upper(Band.name, alias="name")).run_sync() + self.assertListEqual(response, [{"name": "PYTHONISTAS"}]) + + def test_joined_column(self): + """ + Make sure we can uppercase a column's value from a joined table. 
+ """ + response = Band.select(Upper(Band.manager._.name)).run_sync() + self.assertListEqual(response, [{"upper": "GUIDO"}]) + + +@pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.44), + reason="SQLite version not supported", +) +class TestConcat(BandTest): + + def test_column_and_string(self): + response = Band.select( + Concat(Band.name, "!!!", alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas!!!"}]) + + def test_column_and_column(self): + response = Band.select( + Concat(Band.name, Band.popularity, alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas1000"}]) + + def test_join(self): + response = Band.select( + Concat(Band.name, "-", Band.manager._.name, alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas-Guido"}]) + + def test_min_args(self): + with self.assertRaises(ValueError): + Concat() diff --git a/tests/query/functions/test_type_conversion.py b/tests/query/functions/test_type_conversion.py new file mode 100644 index 000000000..598d9d37c --- /dev/null +++ b/tests/query/functions/test_type_conversion.py @@ -0,0 +1,134 @@ +from piccolo.columns import Integer, Text, Varchar +from piccolo.query.functions import Cast, Length +from tests.example_apps.music.tables import Band, Manager + +from .base import BandTest + + +class TestCast(BandTest): + def test_varchar(self): + """ + Make sure that casting to ``Varchar`` works. + """ + response = Band.select( + Cast( + Band.popularity, + as_type=Varchar(), + ) + ).run_sync() + + self.assertListEqual( + response, + [{"popularity": "1000"}], + ) + + def test_text(self): + """ + Make sure that casting to ``Text`` works. + """ + response = Band.select( + Cast( + Band.popularity, + as_type=Text(), + ) + ).run_sync() + + self.assertListEqual( + response, + [{"popularity": "1000"}], + ) + + def test_integer(self): + """ + Make sure that casting to ``Integer`` works. + """ + Band.update({Band.name: "1111"}, force=True).run_sync() + + response = Band.select( + Cast( + Band.name, + as_type=Integer(), + ) + ).run_sync() + + self.assertListEqual( + response, + [{"name": 1111}], + ) + + def test_join(self): + """ + Make sure that casting works with joins. + """ + Manager.update({Manager.name: "1111"}, force=True).run_sync() + + response = Band.select( + Band.name, + Cast( + Band.manager.name, + as_type=Integer(), + ), + ).run_sync() + + self.assertListEqual( + response, + [ + { + "name": "Pythonistas", + "manager.name": 1111, + } + ], + ) + + def test_nested_inner(self): + """ + Make sure ``Cast`` can be passed into other functions. + """ + Band.update({Band.name: "1111"}, force=True).run_sync() + + response = Band.select( + Length( + Cast( + Band.popularity, + as_type=Varchar(), + ) + ) + ).run_sync() + + self.assertListEqual( + response, + [{"length": 4}], + ) + + def test_nested_outer(self): + """ + Make sure a querystring can be passed into ``Cast`` (meaning it can be + nested). + """ + response = Band.select( + Cast( + Length(Band.name), + as_type=Varchar(), + alias="length", + ) + ).run_sync() + + self.assertListEqual( + response, + [{"length": str(len("Pythonistas"))}], + ) + + def test_where_clause(self): + """ + Make sure ``Cast`` works in a where clause. 
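+ The comparison happens in the database, with ``popularity`` cast to a string, which is why we compare against ``"1000"``.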
+ """ + response = ( + Band.select(Band.name, Band.popularity) + .where(Cast(Band.popularity, Varchar()) == "1000") + .run_sync() + ) + + self.assertListEqual( + response, + [{"name": "Pythonistas", "popularity": 1000}], + ) diff --git a/tests/query/mixins/__init__.py b/tests/query/mixins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/query/mixins/test_columns_delegate.py b/tests/query/mixins/test_columns_delegate.py new file mode 100644 index 000000000..e16a13dda --- /dev/null +++ b/tests/query/mixins/test_columns_delegate.py @@ -0,0 +1,62 @@ +import time # For time travel queries. + +from piccolo.query.mixins import ColumnsDelegate +from tests.base import DBTestCase, engines_only +from tests.example_apps.music.tables import Band + + +class TestColumnsDelegate(DBTestCase): + def test_list_unpacking(self): + """ + The ``ColumnsDelegate`` should unpack a list of columns if passed in by + mistake, without the user unpacking them explicitly. + + .. code-block:: python + + # These two should both work the same: + await Band.select([Band.id, Band.name]).run() + await Band.select(Band.id, Band.name).run() + + """ + columns_delegate = ColumnsDelegate() + + columns_delegate.columns([Band.name]) + self.assertEqual(columns_delegate.selected_columns, [Band.name]) + + columns_delegate.columns([Band.id]) + self.assertEqual( + columns_delegate.selected_columns, [Band.name, Band.id] + ) + + @engines_only("cockroach") + def test_as_of(self): + """ + Time travel queries using "As Of" syntax. + Currently supports Cockroach using AS OF SYSTEM TIME. + """ + self.insert_rows() + time.sleep(1) # Ensure time travel queries have some history to use! + + query = ( + Band.select() + .where(Band.name == "Pythonistas") + .as_of("-500ms") + .limit(1) + ) + self.assertTrue("AS OF SYSTEM TIME '-500ms'" in str(query)) + result = query.run_sync() + + self.assertTrue(result[0]["name"] == "Pythonistas") + + query = Band.select().as_of() + self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(query)) + result = query.run_sync() + + self.assertTrue(result[0]["name"] == "Pythonistas") + + # Alternative syntax. + query = Band.objects().get(Band.name == "Pythonistas").as_of("-1s") + self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(query)) + result = query.run_sync() + + self.assertTrue(result.name == "Pythonistas") # type: ignore diff --git a/tests/query/mixins/test_order_by_delegate.py b/tests/query/mixins/test_order_by_delegate.py new file mode 100644 index 000000000..7d2f2c6c4 --- /dev/null +++ b/tests/query/mixins/test_order_by_delegate.py @@ -0,0 +1,19 @@ +from unittest import TestCase + +from piccolo.query.mixins import OrderByDelegate + + +class TestOrderByDelegate(TestCase): + def test_no_columns(self): + """ + An exception should be raised if no columns are passed in. 
+ """ + delegate = OrderByDelegate() + + with self.assertRaises(ValueError) as manager: + delegate.order_by() + + self.assertEqual( + manager.exception.__str__(), + "At least one column must be passed to order_by.", + ) diff --git a/tests/query/operators/__init__.py b/tests/query/operators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/query/operators/test_json.py b/tests/query/operators/test_json.py new file mode 100644 index 000000000..d7840ef9b --- /dev/null +++ b/tests/query/operators/test_json.py @@ -0,0 +1,52 @@ +from unittest import TestCase + +from piccolo.columns import JSONB +from piccolo.query.operators.json import GetChildElement, GetElementFromPath +from piccolo.table import Table +from tests.base import engines_skip + + +class RecordingStudio(Table): + facilities = JSONB(null=True) + + +@engines_skip("sqlite") +class TestGetChildElement(TestCase): + + def test_query(self): + """ + Make sure the generated SQL looks correct. + """ + querystring = GetChildElement( + GetChildElement(RecordingStudio.facilities, "a"), "b" + ) + + sql, query_args = querystring.compile_string() + + self.assertEqual( + sql, + '"recording_studio"."facilities" -> $1 -> $2', + ) + + self.assertListEqual(query_args, ["a", "b"]) + + +@engines_skip("sqlite") +class TestGetElementFromPath(TestCase): + + def test_query(self): + """ + Make sure the generated SQL looks correct. + """ + querystring = GetElementFromPath( + RecordingStudio.facilities, ["a", "b"] + ) + + sql, query_args = querystring.compile_string() + + self.assertEqual( + sql, + '"recording_studio"."facilities" #> $1', + ) + + self.assertListEqual(query_args, [["a", "b"]]) diff --git a/tests/query/test_await.py b/tests/query/test_await.py index 867ba7b98..e34abaef7 100644 --- a/tests/query/test_await.py +++ b/tests/query/test_await.py @@ -1,7 +1,7 @@ import asyncio -from ..base import DBTestCase -from ..example_app.tables import Band +from tests.base import DBTestCase +from tests.example_apps.music.tables import Band class TestAwait(DBTestCase): diff --git a/tests/query/test_camelcase.py b/tests/query/test_camelcase.py new file mode 100644 index 000000000..3cbc6cf04 --- /dev/null +++ b/tests/query/test_camelcase.py @@ -0,0 +1,61 @@ +from unittest import TestCase + +from piccolo.columns import ForeignKey, Varchar +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync + + +class Manager(Table): + theName = Varchar() + + +class Band(Table): + theName = Varchar() + theManager = ForeignKey(Manager) + + +class TestCamelCase(TestCase): + def setUp(self): + create_db_tables_sync(Manager, Band) + + def tearDown(self): + drop_db_tables_sync(Manager, Band) + + def test_queries(self): + """ + Make sure that basic queries work when the columns use camelCase. 
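+ Piccolo doesn't require snake_case column names, so create, select (with joins), delete, and exists should all work unchanged.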
+ """ + manager_names = ("Guido", "Maz", "Graydon") + band_names = ("Pythonistas", "Rubyists", "Rustaceans") + + # Test create + for manager_name, band_name in zip(manager_names, band_names): + manager = Manager.objects().create(theName=manager_name).run_sync() + Band.objects().create( + theName=band_name, theManager=manager + ).run_sync() + + # Test select, with joins + response = ( + Band.select( + Band.theName, + Band.theManager.theName.as_alias("theManagerName"), + ) + .order_by(Band.theName) + .run_sync() + ) + self.assertListEqual( + response, + [ + {"theName": "Pythonistas", "theManagerName": "Guido"}, + {"theName": "Rubyists", "theManagerName": "Maz"}, + {"theName": "Rustaceans", "theManagerName": "Graydon"}, + ], + ) + + # Test delete + Band.delete().where(Band.theName == "Rubyists").run_sync() + + # Test exists + self.assertFalse( + Band.exists().where(Band.theName == "Rubyists").run_sync() + ) diff --git a/tests/query/test_freeze.py b/tests/query/test_freeze.py index 7c5d3e744..61a54e761 100644 --- a/tests/query/test_freeze.py +++ b/tests/query/test_freeze.py @@ -1,26 +1,29 @@ import timeit -import typing as t from dataclasses import dataclass +from typing import Any, Union +from unittest import mock -from piccolo.query.base import Query -from tests.base import DBTestCase, sqlite_only -from tests.example_app.tables import Band +from piccolo.columns import Integer, Varchar +from piccolo.query.base import FrozenQuery, Query +from piccolo.table import Table +from tests.base import AsyncMock, DBTestCase, sqlite_only +from tests.example_apps.music.tables import Band @dataclass class QueryResponse: - query: Query - response: t.Any + query: Union[Query, FrozenQuery] + response: Any class TestFreeze(DBTestCase): - def test_frozen_select_queries(self): + def test_frozen_select_queries(self) -> None: """ Make sure a variety of select queries work as expected when frozen. """ self.insert_rows() - query_responses: t.List[QueryResponse] = [ + query_responses: list[QueryResponse] = [ QueryResponse( query=( Band.select(Band.name) @@ -79,12 +82,26 @@ def test_frozen_performance(self): The frozen query performance should exceed the non-frozen. If not, there's a problem. - Only test this on SQLite, as the latency from the database itself - is more predictable than with Postgres, and the test runs quickly. + We mock out the database to make the performance more predictable. 
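+ A frozen query should win because its SQL is compiled once, up front, rather than on every run.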
""" + db = mock.MagicMock() + db.engine_type = "sqlite" + db.run_querystring = AsyncMock() + db.run_querystring.return_value = [ + {"name": "Pythonistas", "popularity": 1000} + ] + + class Band(Table, db=db): + name = Varchar() + popularity = Integer() + iterations = 50 - query = Band.select().where(Band.name == "Pythonistas").first() + query = ( + Band.select(Band.name) + .where(Band.popularity > 900) + .order_by(Band.name) + ) query_duration = timeit.repeat( lambda: query.run_sync(), repeat=iterations, number=1 ) @@ -95,9 +112,9 @@ def test_frozen_performance(self): ) # Remove the outliers before comparing - self.assertTrue( - sum(sorted(query_duration)[5:-5]) - > sum(sorted(frozen_query_duration)[5:-5]) + self.assertGreater( + sum(sorted(query_duration)[10:-10]), + sum(sorted(frozen_query_duration)[10:-10]), ) def test_attribute_access(self): diff --git a/tests/query/test_gather.py b/tests/query/test_gather.py new file mode 100644 index 000000000..548efa9b5 --- /dev/null +++ b/tests/query/test_gather.py @@ -0,0 +1,24 @@ +import asyncio + +from tests.base import DBTestCase +from tests.example_apps.music.tables import Manager + + +class TestAwait(DBTestCase): + def test_await(self): + """ + Make sure that asyncio.gather works with the main query types. + """ + + async def run_queries(): + return await asyncio.gather( + Manager.select(), + Manager.insert(Manager(name="Golangs")), + Manager.delete().where(Manager.name != "Golangs"), + Manager.objects(), + Manager.count(), + Manager.raw("SELECT * FROM manager"), + ) + + # No exceptions should be raised. + self.assertIsInstance(asyncio.run(run_queries()), list) diff --git a/tests/query/test_querystring.py b/tests/query/test_querystring.py index 59631ab7f..58ca29495 100644 --- a/tests/query/test_querystring.py +++ b/tests/query/test_querystring.py @@ -1,6 +1,7 @@ from unittest import TestCase from piccolo.querystring import QueryString +from tests.base import postgres_only # TODO - add more extensive tests (increased nesting and argument count). @@ -28,3 +29,138 @@ def test_string(self): def test_querystring_with_no_args(self): qs = QueryString("SELECT name FROM band") self.assertEqual(qs.compile_string(), ("SELECT name FROM band", [])) + + +@postgres_only +class TestQueryStringOperators(TestCase): + """ + Make sure basic operations can be used on ``QueryString``. 
+ """ + + def test_add(self): + query = QueryString("SELECT price") + 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price + $1", [1]), + ) + + def test_multiply(self): + query = QueryString("SELECT price") * 2 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price * $1", [2]), + ) + + def test_divide(self): + query = QueryString("SELECT price") / 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price / $1", [1]), + ) + + def test_power(self): + query = QueryString("SELECT price") ** 2 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price ^ $1", [2]), + ) + + def test_subtract(self): + query = QueryString("SELECT price") - 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price - $1", [1]), + ) + + def test_modulus(self): + query = QueryString("SELECT price") % 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price % $1", [1]), + ) + + def test_like(self): + query = QueryString("strip(name)").like("Python%") + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("strip(name) LIKE $1", ["Python%"]), + ) + + def test_ilike(self): + query = QueryString("strip(name)").ilike("Python%") + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("strip(name) ILIKE $1", ["Python%"]), + ) + + def test_greater_than(self): + query = QueryString("SELECT price") > 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price > $1", [10]), + ) + + def test_greater_equal_than(self): + query = QueryString("SELECT price") >= 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price >= $1", [10]), + ) + + def test_less_than(self): + query = QueryString("SELECT price") < 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price < $1", [10]), + ) + + def test_less_equal_than(self): + query = QueryString("SELECT price") <= 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price <= $1", [10]), + ) + + def test_equals(self): + query = QueryString("SELECT price") == 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price = $1", [10]), + ) + + def test_not_equals(self): + query = QueryString("SELECT price") != 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price != $1", [10]), + ) + + def test_is_in(self): + query = QueryString("SELECT price").is_in([10, 20, 30]) + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price IN $1", [[10, 20, 30]]), + ) + + def test_not_in(self): + query = QueryString("SELECT price").not_in([10, 20, 30]) + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(), + ("SELECT price NOT IN $1", [[10, 20, 30]]), + ) diff --git a/tests/query/test_slots.py b/tests/query/test_slots.py index 7777f4158..ad6322502 100644 --- a/tests/query/test_slots.py +++ b/tests/query/test_slots.py @@ -13,8 +13,7 @@ TableExists, Update, ) - -from ..example_app.tables import Manager +from tests.example_apps.music.tables import Manager 
class TestSlots(TestCase): @@ -42,4 +41,4 @@ def test_attributes(self): AttributeError, msg=f"{class_name} didn't raised an error" ): print(f"Setting {class_name} attribute") - query_class(table=Manager).abc = 123 + query_class(table=Manager).abc = 123 # type: ignore diff --git a/tests/sqlite_conf.py b/tests/sqlite_conf.py index 6564b4b94..afcc6185e 100644 --- a/tests/sqlite_conf.py +++ b/tests/sqlite_conf.py @@ -4,4 +4,9 @@ DB = SQLiteEngine(path="test.sqlite") -APP_REGISTRY = AppRegistry(apps=["tests.example_app.piccolo_app"]) +APP_REGISTRY = AppRegistry( + apps=[ + "tests.example_apps.music.piccolo_app", + "tests.example_apps.mega.piccolo_app", + ] +) diff --git a/tests/table/instance/test_create.py b/tests/table/instance/test_create.py new file mode 100644 index 000000000..6e4856cc2 --- /dev/null +++ b/tests/table/instance/test_create.py @@ -0,0 +1,39 @@ +from unittest import TestCase + +from piccolo.columns import Integer, Varchar +from piccolo.table import Table + + +class Band(Table): + name = Varchar(default=None, null=False) + popularity = Integer() + + +class TestCreate(TestCase): + def setUp(self): + Band.create_table().run_sync() + + def tearDown(self): + Band.alter().drop_table().run_sync() + + def test_create_new(self): + """ + Make sure that creating a new instance works. + """ + Band.objects().create(name="Pythonistas", popularity=1000).run_sync() + + names = [i["name"] for i in Band.select(Band.name).run_sync()] + self.assertTrue("Pythonistas" in names) + + def test_null_values(self): + """ + Make sure we test non-null columns: + https://github.com/piccolo-orm/piccolo/issues/652 + """ + with self.assertRaises(ValueError) as manager: + Band.objects().create().run_sync() + + self.assertEqual(str(manager.exception), "name wasn't provided") + + # Shouldn't raise an exception + Band.objects().create(name="Pythonistas").run_sync() diff --git a/tests/table/instance/test_equality.py b/tests/table/instance/test_equality.py new file mode 100644 index 000000000..40ae59517 --- /dev/null +++ b/tests/table/instance/test_equality.py @@ -0,0 +1,74 @@ +from piccolo.columns.column_types import UUID, Varchar +from piccolo.table import Table +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Manager + + +class ManagerUUID(Table): + id = UUID(primary_key=True) + name = Varchar() + + +class TestInstanceEquality(AsyncTableTest): + tables = [ + Manager, + ManagerUUID, + ] + + async def test_instance_equality(self) -> None: + """ + Make sure instance equality works, for tables with a `Serial` primary + key. + """ + manager_1 = Manager(name="Guido") + await manager_1.save() + + manager_2 = Manager(name="Graydon") + await manager_2.save() + + self.assertEqual(manager_1, manager_1) + self.assertNotEqual(manager_1, manager_2) + + # Try fetching the row from the database. + manager_1_from_db = ( + await Manager.objects().where(Manager.id == manager_1.id).first() + ) + self.assertEqual(manager_1, manager_1_from_db) + self.assertNotEqual(manager_2, manager_1_from_db) + + # Try rows which haven't been saved yet. + # They have no primary key value (because they use Serial columns + # as the primary key), so they shouldn't be equal. + self.assertNotEqual(Manager(), Manager()) + self.assertNotEqual(manager_1, Manager()) + + # Make sure an object is equal to itself, even if not saved. 
+ manager_unsaved = Manager() + self.assertEqual(manager_unsaved, manager_unsaved) + + async def test_instance_equality_uuid(self) -> None: + """ + Make sure instance equality works, for tables with a `UUID` primary + key. + """ + manager_1 = ManagerUUID(name="Guido") + await manager_1.save() + + manager_2 = ManagerUUID(name="Graydon") + await manager_2.save() + + self.assertEqual(manager_1, manager_1) + self.assertNotEqual(manager_1, manager_2) + + # Try fetching the row from the database. + manager_1_from_db = ( + await ManagerUUID.objects() + .where(ManagerUUID.id == manager_1.id) + .first() + ) + self.assertEqual(manager_1, manager_1_from_db) + self.assertNotEqual(manager_2, manager_1_from_db) + + # Make sure an object is equal to itself, even if not saved. + manager_unsaved = ManagerUUID() + self.assertEqual(manager_unsaved, manager_unsaved) diff --git a/tests/table/instance/test_get_related.py b/tests/table/instance/test_get_related.py index 6e5025602..b662f54a0 100644 --- a/tests/table/instance/test_get_related.py +++ b/tests/table/instance/test_get_related.py @@ -1,29 +1,98 @@ -from unittest import TestCase +from typing import cast -from tests.example_app.tables import Band, Manager +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Band, Concert, Manager, Venue -TABLES = [Manager, Band] +class TestGetRelated(AsyncTableTest): + tables = [Manager, Band, Concert, Venue] -class TestGetRelated(TestCase): - def setUp(self): - for table in TABLES: - table.create_table().run_sync() + async def asyncSetUp(self): + await super().asyncSetUp() - def tearDown(self): - for table in reversed(TABLES): - table.alter().drop_table().run_sync() + # Set up two pairs of manager/band, so we can make sure the correct + # objects are returned. - def test_get_related(self): + self.manager = Manager(name="Guido") + await self.manager.save() + + self.band = Band( + name="Pythonistas", manager=self.manager.id, popularity=100 + ) + await self.band.save() + + self.manager_2 = Manager(name="Graydon") + await self.manager_2.save() + + self.band_2 = Band( + name="Rustaceans", manager=self.manager_2.id, popularity=100 + ) + await self.band_2.save() + + async def test_foreign_key(self) -> None: """ Make sure you can get a related object from another object instance. """ - manager = Manager(name="Guido") - manager.save().run_sync() + manager = await self.band.get_related(Band.manager) + assert manager is not None + self.assertTrue(manager.id == self.manager.id) + + manager_2 = await self.band_2.get_related(Band.manager) + assert manager_2 is not None + self.assertTrue(manager_2.id == self.manager_2.id) + + async def test_non_foreign_key(self): + """ + Make sure that non-ForeignKey columns raise an exception. + """ + with self.assertRaises(ValueError): + self.band.get_related(Band.name) # type: ignore + + async def test_string(self): + """ + Make sure it also works using a string representation of a foreign key. + """ + manager = cast(Manager, await self.band.get_related("manager")) + self.assertTrue(manager.id == self.manager.id) + + async def test_invalid_string(self): + """ + Make sure an exception is raised if the foreign key string is invalid. + """ + with self.assertRaises(ValueError): + self.band.get_related("abc123") + + async def test_multiple_levels(self): + """ + Make sure ``get_related`` works multiple levels deep. 
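The call shape being tested here, in isolation: a minimal sketch (assuming the example app's ``Concert`` rows have been saved, as in ``asyncSetUp`` above):

from tests.example_apps.music.tables import Concert

async def example() -> None:
    concert = await Concert.objects().first()
    assert concert is not None
    # Follow two foreign keys in one call: Concert -> Band -> Manager.
    manager = await concert.get_related(Concert.band_1._.manager)
    print(manager)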
+ """ + concert = Concert(band_1=self.band, band_2=self.band_2) + await concert.save() + + manager = await concert.get_related(Concert.band_1._.manager) + assert manager is not None + self.assertTrue(manager.id == self.manager.id) + + manager_2 = await concert.get_related(Concert.band_2._.manager) + assert manager_2 is not None + self.assertTrue(manager_2.id == self.manager_2.id) + + async def test_no_match(self): + """ + If not related object exists, make sure ``None`` is returned. + """ + concert = Concert(band_1=self.band, band_2=None) + await concert.save() - band = Band(name="Pythonistas", manager=manager.id, popularity=100) - band.save().run_sync() + manager_2 = await concert.get_related(Concert.band_2._.manager) + assert manager_2 is None - _manager = band.get_related(Band.manager).run_sync() + async def test_not_in_db(self): + """ + If the object we're calling ``get_related`` on doesn't exist in the + database, then make sure an error is raised. + """ + concert = Concert(band_1=self.band, band_2=self.band_2) - self.assertTrue(_manager.name == "Guido") + with self.assertRaises(ValueError): + await concert.get_related(Concert.band_1._.manager) diff --git a/tests/table/instance/test_get_related_readable.py b/tests/table/instance/test_get_related_readable.py index 649b10099..982c4a5bc 100644 --- a/tests/table/instance/test_get_related_readable.py +++ b/tests/table/instance/test_get_related_readable.py @@ -1,22 +1,168 @@ -from tests.base import DBTestCase -from tests.example_app.tables import Band +import decimal +from unittest import TestCase +from piccolo.columns import ForeignKey, Varchar +from piccolo.columns.readable import Readable +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import engine_is +from tests.example_apps.music.tables import ( + Band, + Concert, + Manager, + Ticket, + Venue, +) + + +class ThingOne(Table): + name = Varchar(length=300, null=False) + + +class ThingTwo(Table): + name = Varchar(length=300, null=False) + thing_one = ForeignKey(references=ThingOne) + + +class ThingThree(Table): + name = Varchar(length=300, null=False) + thing_two = ForeignKey(references=ThingTwo) + + @classmethod + def get_readable(cls): + return Readable( + template="three name: %s - two name: %s - one name: %s", + columns=[ + cls.name, + cls.thing_two.name, + cls.thing_two._.thing_one._.name, + ], + ) + + +class ThingFour(Table): + name = Varchar(length=300, null=False) + thing_three = ForeignKey(references=ThingThree) + + +TABLES = [ + Band, + Concert, + Manager, + Venue, + Ticket, + ThingOne, + ThingTwo, + ThingThree, + ThingFour, +] + + +class TestGetRelatedReadable(TestCase): + def setUp(self): + create_db_tables_sync(*TABLES) + + manager_1 = Manager.objects().create(name="Guido").run_sync() + manager_2 = Manager.objects().create(name="Graydon").run_sync() + + band_1 = ( + Band.objects() + .create(name="Pythonistas", manager=manager_1) + .run_sync() + ) + band_2 = ( + Band.objects() + .create(name="Rustaceans", manager=manager_2) + .run_sync() + ) + venue = ( + Venue.objects() + .create(name="Royal Albert Hall", capacity=5900) + .run_sync() + ) + concert = ( + Concert.objects() + .create(venue=venue, band_1=band_1, band_2=band_2) + .run_sync() + ) + Ticket.objects().create( + price=decimal.Decimal(50.0), concert=concert + ).run_sync() + + thing_one = ThingOne.insert(ThingOne(name="thing_one")).run_sync() + thing_two = ThingTwo.insert( + ThingTwo(name="thing_two", thing_one=thing_one[0]["id"]) + ).run_sync() + thing_three = 
ThingThree.insert( + ThingThree(name="thing_three", thing_two=thing_two[0]["id"]) + ).run_sync() + ThingFour.insert( + ThingFour(name="thing_four", thing_three=thing_three[0]["id"]) + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*TABLES) -class TestGetRelatedReadable(DBTestCase): def test_get_related_readable(self): """ Make sure you can get the `Readable` representation for related object from another object instance. """ - self.insert_row() - response = Band.select( Band.name, Band._get_related_readable(Band.manager) ).run_sync() self.assertEqual( - response, [{"name": "Pythonistas", "manager_readable": "Guido"}] + response, + [ + {"name": "Pythonistas", "manager_readable": "Guido"}, + {"manager_readable": "Graydon", "name": "Rustaceans"}, + ], ) - # TODO Need to make sure it can go two levels deep ... - # e.g. Concert._get_related_readable(Concert.band_1.manager) + # Now try something much more complex. + response = Ticket.select( + Ticket.id, + Ticket._get_related_readable(Ticket.concert), + ).run_sync() + + if engine_is("cockroach"): + self.assertEqual( + response, + [ + { + "id": response[0]["id"], + "concert_readable": ( + "Pythonistas and Rustaceans at Royal Albert Hall, " + "capacity 5900" + ), + } + ], + ) + else: + self.assertEqual( + response, + [ + { + "id": 1, + "concert_readable": ( + "Pythonistas and Rustaceans at Royal Albert Hall, " + "capacity 5900" + ), + } + ], + ) + + # A really complex references chain from Piccolo Admin issue #170 + response = ThingFour.select( + ThingFour._get_related_readable(ThingFour.thing_three) + ).run_sync() + self.assertEqual( + response, + [ + { + "thing_three_readable": ( + "three name: thing_three - two name: thing_two - one name: thing_one" # noqa: E501 + ) + } + ], + ) diff --git a/tests/table/instance/test_instantiate.py b/tests/table/instance/test_instantiate.py index 154b18df1..6fceaa2be 100644 --- a/tests/table/instance/test_instantiate.py +++ b/tests/table/instance/test_instantiate.py @@ -1,5 +1,5 @@ -from tests.base import DBTestCase, postgres_only, sqlite_only -from tests.example_app.tables import Band +from tests.base import DBTestCase, engines_only, sqlite_only +from tests.example_apps.music.tables import Band class TestInstance(DBTestCase): @@ -7,13 +7,20 @@ class TestInstance(DBTestCase): Test instantiating Table instances """ - @postgres_only + @engines_only("postgres") def test_insert_postgres(self): Pythonistas = Band(name="Pythonistas") self.assertEqual( Pythonistas.__str__(), "(DEFAULT,'Pythonistas',null,0)" ) + @engines_only("cockroach") + def test_insert_postgres_alt(self): + Pythonistas = Band(name="Pythonistas") + self.assertEqual( + Pythonistas.__str__(), "(unique_rowid(),'Pythonistas',null,0)" + ) + @sqlite_only def test_insert_sqlite(self): Pythonistas = Band(name="Pythonistas") diff --git a/tests/table/instance/test_remove.py b/tests/table/instance/test_remove.py index 18045f2ab..494b4fc89 100644 --- a/tests/table/instance/test_remove.py +++ b/tests/table/instance/test_remove.py @@ -1,6 +1,6 @@ from unittest import TestCase -from tests.example_app.tables import Manager +from tests.example_apps.music.tables import Manager class TestRemove(TestCase): @@ -17,9 +17,11 @@ def test_remove(self): "Maz" in Manager.select(Manager.name).output(as_list=True).run_sync() ) + self.assertEqual(manager._exists_in_db, True) manager.remove().run_sync() self.assertTrue( "Maz" not in Manager.select(Manager.name).output(as_list=True).run_sync() ) + self.assertEqual(manager._exists_in_db, False) diff --git 
a/tests/table/instance/test_save.py b/tests/table/instance/test_save.py index f5c17b73f..67a27ee85 100644 --- a/tests/table/instance/test_save.py +++ b/tests/table/instance/test_save.py @@ -1,14 +1,16 @@ from unittest import TestCase -from tests.example_app.tables import Manager +from piccolo.table import create_db_tables_sync, drop_db_tables_sync +from tests.base import engines_only, engines_skip +from tests.example_apps.music.tables import Band, Manager class TestSave(TestCase): def setUp(self): - Manager.create_table().run_sync() + create_db_tables_sync(Manager, Band) def tearDown(self): - Manager.alter().drop_table().run_sync() + drop_db_tables_sync(Manager, Band) def test_save_new(self): """ @@ -34,3 +36,119 @@ def test_save_new(self): names = [i["name"] for i in Manager.select(Manager.name).run_sync()] self.assertTrue("Maz2" in names) self.assertTrue("Maz" not in names) + + @engines_skip("cockroach") + def test_save_specific_columns(self): + """ + Make sure that we can save a subset of columns. + """ + manager = Manager(name="Guido") + manager.save().run_sync() + + band = Band(name="Pythonistas", popularity=1000, manager=manager) + band.save().run_sync() + + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "popularity": 1000, + } + ], + ) + + band.name = "Pythonistas 2" + band.popularity = 2000 + band.save(columns=[Band.name]).run_sync() + + # Only the name should update, and not the popularity: + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": 1, + "name": "Pythonistas 2", + "manager": 1, + "popularity": 1000, + } + ], + ) + + # Also test it using strings to identify columns + band.name = "Pythonistas 3" + band.popularity = 3000 + band.save(columns=["popularity"]).run_sync() + + # Only the popularity should update, and not the name: + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": 1, + "name": "Pythonistas 2", + "manager": 1, + "popularity": 3000, + } + ], + ) + + @engines_only("cockroach") + def test_save_specific_columns_alt(self): + """ + Make sure that we can save a subset of columns. 
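The partial-save pattern exercised by these tests looks like this in application code (a sketch, assuming a saved ``Band`` row; ``columns`` accepts ``Column`` references or column-name strings):

band = Band.objects().where(Band.name == "Pythonistas").first().run_sync()
assert band is not None
band.name = "Pythonistas 2"
band.popularity = 2000
# Only the name is written back; popularity keeps its stored value.
band.save(columns=[Band.name]).run_sync()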
+ """ + manager = Manager(name="Guido") + manager.save().run_sync() + + band = Band(name="Pythonistas", popularity=1000, manager=manager) + band.save().run_sync() + + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": band.id, + "name": "Pythonistas", + "manager": band.manager.id, + "popularity": 1000, + } + ], + ) + + band.name = "Pythonistas 2" + band.popularity = 2000 + band.save(columns=[Band.name]).run_sync() + + # Only the name should update, and not the popularity: + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": band.id, + "name": "Pythonistas 2", + "manager": band.manager.id, + "popularity": 1000, + } + ], + ) + + # Also test it using strings to identify columns + band.name = "Pythonistas 3" + band.popularity = 3000 + band.save(columns=["popularity"]).run_sync() + + # Only the popularity should update, and not the name: + self.assertEqual( + Band.select().run_sync(), + [ + { + "id": band.id, + "name": "Pythonistas 2", + "manager": band.manager.id, + "popularity": 3000, + } + ], + ) diff --git a/tests/table/instance/test_to_dict.py b/tests/table/instance/test_to_dict.py new file mode 100644 index 000000000..b5d75f52b --- /dev/null +++ b/tests/table/instance/test_to_dict.py @@ -0,0 +1,110 @@ +from tests.base import DBTestCase, engine_is +from tests.example_apps.music.tables import Band, Manager + + +class TestToDict(DBTestCase): + def test_to_dict(self): + """ + Make sure that `to_dict` works correctly. + """ + self.insert_row() + + instance = Manager.objects().first().run_sync() + assert instance is not None + dictionary = instance.to_dict() + if engine_is("cockroach"): + self.assertDictEqual( + dictionary, {"id": dictionary["id"], "name": "Guido"} + ) + else: + self.assertDictEqual(dictionary, {"id": 1, "name": "Guido"}) + + def test_nested(self): + """ + Make sure that `to_dict` works correctly, when the object contains + nested objects. + """ + self.insert_row() + + instance = Band.objects(Band.manager).first().run_sync() + assert instance is not None + dictionary = instance.to_dict() + if engine_is("cockroach"): + self.assertDictEqual( + dictionary, + { + "id": dictionary["id"], + "name": "Pythonistas", + "manager": { + "id": instance["manager"]["id"], + "name": "Guido", + }, + "popularity": 1000, + }, + ) + else: + self.assertDictEqual( + dictionary, + { + "id": 1, + "name": "Pythonistas", + "manager": {"id": 1, "name": "Guido"}, + "popularity": 1000, + }, + ) + + def test_filter_rows(self): + """ + Make sure that `to_dict` works correctly with a subset of columns. + """ + self.insert_row() + + instance = Manager.objects().first().run_sync() + assert instance is not None + dictionary = instance.to_dict(Manager.name) + self.assertDictEqual(dictionary, {"name": "Guido"}) + + def test_nested_filter(self): + """ + Make sure that `to_dict` works correctly with nested objects and + filtering. + """ + self.insert_row() + + instance = Band.objects(Band.manager).first().run_sync() + assert instance is not None + dictionary = instance.to_dict(Band.name, Band.manager.id) + if engine_is("cockroach"): + self.assertDictEqual( + dictionary, + { + "name": "Pythonistas", + "manager": {"id": dictionary["manager"]["id"]}, + }, + ) + else: + self.assertDictEqual( + dictionary, + { + "name": "Pythonistas", + "manager": {"id": 1}, + }, + ) + + def test_aliases(self): + """ + Make sure that `to_dict` works correctly with aliases. 
+ """ + self.insert_row() + + instance = Manager.objects().first().run_sync() + assert instance is not None + dictionary = instance.to_dict( + Manager.id, Manager.name.as_alias("title") + ) + if engine_is("cockroach"): + self.assertDictEqual( + dictionary, {"id": dictionary["id"], "title": "Guido"} + ) + else: + self.assertDictEqual(dictionary, {"id": 1, "title": "Guido"}) diff --git a/tests/table/test_all_columns.py b/tests/table/test_all_columns.py new file mode 100644 index 000000000..cdbc0d98a --- /dev/null +++ b/tests/table/test_all_columns.py @@ -0,0 +1,23 @@ +from unittest import TestCase + +from tests.example_apps.music.tables import Band + + +class TestAllColumns(TestCase): + def test_all_columns(self): + self.assertEqual( + Band.all_columns(), + [Band.id, Band.name, Band.manager, Band.popularity], + ) + self.assertEqual(Band.all_columns(), Band._meta.columns) + + def test_all_columns_excluding(self): + self.assertEqual( + Band.all_columns(exclude=[Band.id]), + [Band.name, Band.manager, Band.popularity], + ) + + self.assertEqual( + Band.all_columns(exclude=["id"]), + [Band.name, Band.manager, Band.popularity], + ) diff --git a/tests/table/test_alter.py b/tests/table/test_alter.py index a8cfa0387..32057b9f0 100644 --- a/tests/table/test_alter.py +++ b/tests/table/test_alter.py @@ -1,33 +1,76 @@ +from __future__ import annotations + +from typing import Any, Union from unittest import TestCase +import pytest + from piccolo.columns import BigInt, Integer, Numeric, Varchar +from piccolo.columns.base import Column +from piccolo.columns.column_types import ForeignKey, Text +from piccolo.schema import SchemaManager from piccolo.table import Table - -from ..base import DBTestCase, postgres_only -from ..example_app.tables import Band, Manager - - +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_only, + is_running_sqlite, +) +from tests.example_apps.music.tables import Band, Manager + + +@pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.25), + reason="SQLite version not supported", +) class TestRenameColumn(DBTestCase): - def _test_rename(self, column): + def _test_rename( + self, + existing_column: Union[Column, str], + new_column_name: str = "rating", + ): self.insert_row() - rename_query = Band.alter().rename_column(column, "rating") + rename_query = Band.alter().rename_column( + existing_column, new_column_name + ) rename_query.run_sync() select_query = Band.raw("SELECT * FROM band") response = select_query.run_sync() column_names = response[0].keys() + existing_column_name = ( + existing_column._meta.name + if isinstance(existing_column, Column) + else existing_column + ) self.assertTrue( - ("rating" in column_names) and ("popularity" not in column_names) + (new_column_name in column_names) + and (existing_column_name not in column_names) ) - def test_rename_string(self): + def test_column(self): + """ + Make sure a ``Column`` argument works. + """ self._test_rename(Band.popularity) - def test_rename_column(self): + def test_string(self): + """ + Make sure a string argument works. + """ self._test_rename("popularity") + def test_problematic_name(self): + """ + Make sure we can rename columns with names which clash with SQL + keywords. 
+ """ + self._test_rename( + existing_column=Band.popularity, new_column_name="order" + ) + class TestRenameTable(DBTestCase): def test_rename(self): @@ -39,7 +82,7 @@ def tearDown(self): self.run_sync("DROP TABLE IF EXISTS act") -@postgres_only +@engines_only("postgres", "cockroach") class TestDropColumn(DBTestCase): """ Unfortunately this only works with Postgres at the moment. @@ -47,7 +90,7 @@ class TestDropColumn(DBTestCase): SQLite has very limited support for ALTER statements. """ - def _test_drop(self, column: str): + def _test_drop(self, column: Union[str, Column]): self.insert_row() Band.alter().drop_column(column).run_sync() @@ -55,7 +98,7 @@ def _test_drop(self, column: str): response = Band.raw("SELECT * FROM band").run_sync() column_names = response[0].keys() - self.assertTrue("popularity" not in column_names) + self.assertNotIn("popularity", column_names) def test_drop_string(self): self._test_drop(Band.popularity) @@ -64,31 +107,60 @@ def test_drop_column(self): self._test_drop("popularity") -class TestAdd(DBTestCase): - def test_add(self): - """ - This needs a lot more work. Need to set values for existing rows. - - Just write the test for now ... - """ +class TestAddColumn(DBTestCase): + def _test_add_column( + self, column: Column, column_name: str, expected_value: Any + ): self.insert_row() - - add_query = Band.alter().add_column( - "weight", Integer(null=True, default=None) - ) - add_query.run_sync() + Band.alter().add_column(column_name, column).run_sync() response = Band.raw("SELECT * FROM band").run_sync() column_names = response[0].keys() - self.assertTrue("weight" in column_names) + self.assertIn(column_name, column_names) + + self.assertEqual(response[0][column_name], expected_value) + + def test_integer(self): + self._test_add_column( + column=Integer(null=True, default=None), + column_name="members", + expected_value=None, + ) + + def test_foreign_key(self): + self._test_add_column( + column=ForeignKey(references=Manager), + column_name="assistant_manager", + expected_value=None, + ) - self.assertEqual(response[0]["weight"], None) + def test_text(self): + bio = "An amazing band" + self._test_add_column( + column=Text(default=bio), + column_name="bio", + expected_value=bio, + ) + + def test_problematic_name(self): + """ + Make sure we can add columns with names which clash with SQL keywords. + """ + self._test_add_column( + column=Text(default="asc"), + column_name="order", + expected_value="asc", + ) -@postgres_only class TestUnique(DBTestCase): + @engines_only("postgres") def test_unique(self): + """ + Test altering a column uniqueness with MigrationManager. + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/42840 "unimplemented: cannot drop UNIQUE constraint "manager_name_key" using ALTER TABLE DROP CONSTRAINT, use DROP INDEX CASCADE instead" + """ # noqa: E501 unique_query = Manager.alter().set_unique(Manager.name, True) unique_query.run_sync() @@ -102,7 +174,7 @@ def test_unique(self): Manager(name="Bob").save().run_sync() response = Manager.select().run_sync() - self.assertTrue(len(response) == 2) + self.assertEqual(len(response), 2) # Now remove the constraint, and add a row. not_unique_query = Manager.alter().set_unique(Manager.name, False) @@ -113,7 +185,7 @@ def test_unique(self): self.assertTrue(len(response), 2) -@postgres_only +@engines_only("postgres", "cockroach") class TestMultiple(DBTestCase): """ Make sure multiple alter statements work correctly. 
@@ -132,12 +204,12 @@ def test_multiple(self): response = Band.raw("SELECT * FROM manager").run_sync() column_names = response[0].keys() - self.assertTrue("column_a" in column_names) - self.assertTrue("column_b" in column_names) + self.assertIn("column_a", column_names) + self.assertIn("column_b", column_names) # TODO - test more conversions. -@postgres_only +@engines_only("postgres", "cockroach") class TestSetColumnType(DBTestCase): def test_integer_to_bigint(self): """ @@ -157,10 +229,9 @@ def test_integer_to_bigint(self): "BIGINT", ) - popularity = ( - Band.select(Band.popularity).first().run_sync()["popularity"] - ) - self.assertEqual(popularity, 1000) + row = Band.select(Band.popularity).first().run_sync() + assert row is not None + self.assertEqual(row["popularity"], 1000) def test_integer_to_varchar(self): """ @@ -180,10 +251,9 @@ def test_integer_to_varchar(self): "CHARACTER VARYING", ) - popularity = ( - Band.select(Band.popularity).first().run_sync()["popularity"] - ) - self.assertEqual(popularity, "1000") + row = Band.select(Band.popularity).first().run_sync() + assert row is not None + self.assertEqual(row["popularity"], "1000") def test_using_expression(self): """ @@ -199,11 +269,12 @@ def test_using_expression(self): ) alter_query.run_sync() - popularity = Band.select(Band.name).first().run_sync()["name"] - self.assertEqual(popularity, 1) + row = Band.select(Band.name).first().run_sync() + assert row is not None + self.assertEqual(row["name"], 1) -@postgres_only +@engines_only("postgres", "cockroach") class TestSetNull(DBTestCase): def test_set_null(self): query = """ @@ -215,14 +286,14 @@ def test_set_null(self): Band.alter().set_null(Band.popularity, boolean=True).run_sync() response = Band.raw(query).run_sync() - self.assertTrue(response[0]["is_nullable"] == "YES") + self.assertEqual(response[0]["is_nullable"], "YES") Band.alter().set_null(Band.popularity, boolean=False).run_sync() response = Band.raw(query).run_sync() - self.assertTrue(response[0]["is_nullable"] == "NO") + self.assertEqual(response[0]["is_nullable"], "NO") -@postgres_only +@engines_only("postgres", "cockroach") class TestSetLength(DBTestCase): def test_set_length(self): query = """ @@ -235,10 +306,10 @@ def test_set_length(self): for length in (5, 20, 50): Band.alter().set_length(Band.name, length=length).run_sync() response = Band.raw(query).run_sync() - self.assertTrue(response[0]["character_maximum_length"] == length) + self.assertEqual(response[0]["character_maximum_length"], length) -@postgres_only +@engines_only("postgres", "cockroach") class TestSetDefault(DBTestCase): def test_set_default(self): Manager.alter().set_default(Manager.name, "Pending").run_sync() @@ -249,7 +320,66 @@ def test_set_default(self): ).run_sync() manager = Manager.objects().first().run_sync() - self.assertTrue(manager.name == "Pending") + assert manager is not None + self.assertEqual(manager.name, "Pending") + + +@engines_only("postgres", "cockroach") +class TestSetSchema(TestCase): + schema_manager = SchemaManager() + schema_name = "schema_1" + + def setUp(self): + Manager.create_table().run_sync() + self.schema_manager.create_schema( + schema_name=self.schema_name + ).run_sync() + + def tearDown(self): + Manager.alter().drop_table(if_exists=True).run_sync() + self.schema_manager.drop_schema( + schema_name=self.schema_name, cascade=True + ).run_sync() + + def test_set_schema(self): + Manager.alter().set_schema(schema_name=self.schema_name).run_sync() + + self.assertIn( + Manager._meta.tablename, + 
self.schema_manager.list_tables( + schema_name=self.schema_name + ).run_sync(), + ) + + +@engines_only("postgres", "cockroach") +class TestDropTable(TestCase): + class Manager(Table, schema="schema_1"): + pass + + schema_manager = SchemaManager() + + def tearDown(self): + self.schema_manager.drop_schema( + schema_name="schema_1", if_exists=True, cascade=True + ).run_sync() + + def test_drop_table_with_schema(self): + Manager = self.Manager + + Manager.create_table().run_sync() + + self.assertIn( + "manager", + self.schema_manager.list_tables(schema_name="schema_1").run_sync(), + ) + + Manager.alter().drop_table().run_sync() + + self.assertNotIn( + "manager", + self.schema_manager.list_tables(schema_name="schema_1").run_sync(), + ) ############################################################################### @@ -259,7 +389,6 @@ class Ticket(Table): price = Numeric(digits=(5, 2)) -@postgres_only class TestSetDigits(TestCase): def setUp(self): Ticket.create_table().run_sync() @@ -267,6 +396,7 @@ def setUp(self): def tearDown(self): Ticket.alter().drop_table().run_sync() + @engines_only("postgres") def test_set_digits(self): query = """ SELECT numeric_precision, numeric_scale @@ -280,10 +410,10 @@ def test_set_digits(self): column=Ticket.price, digits=(6, 2) ).run_sync() response = Ticket.raw(query).run_sync() - self.assertTrue(response[0]["numeric_precision"] == 6) - self.assertTrue(response[0]["numeric_scale"] == 2) + self.assertEqual(response[0]["numeric_precision"], 6) + self.assertEqual(response[0]["numeric_scale"], 2) Ticket.alter().set_digits(column=Ticket.price, digits=None).run_sync() response = Ticket.raw(query).run_sync() - self.assertTrue(response[0]["numeric_precision"] is None) - self.assertTrue(response[0]["numeric_scale"] is None) + self.assertIsNone(response[0]["numeric_precision"]) + self.assertIsNone(response[0]["numeric_scale"]) diff --git a/tests/table/test_batch.py b/tests/table/test_batch.py index a67c131d1..762701f53 100644 --- a/tests/table/test_batch.py +++ b/tests/table/test_batch.py @@ -1,8 +1,14 @@ import asyncio import math +from unittest import TestCase -from ..base import DBTestCase -from ..example_app.tables import Manager +from piccolo.columns import Varchar +from piccolo.engine.finder import engine_finder +from piccolo.engine.postgres import AsyncBatch, PostgresEngine +from piccolo.table import Table +from piccolo.utils.sync import run_sync +from tests.base import AsyncMock, DBTestCase, engines_only +from tests.example_apps.music.tables import Manager class TestBatchSelect(DBTestCase): @@ -10,12 +16,12 @@ def _check_results(self, batch): """ Make sure the data is returned in the correct format. """ - self.assertTrue(type(batch) == list) + self.assertEqual(type(batch), list) if len(batch) > 0: row = batch[0] - self.assertTrue(type(row) == dict) - self.assertTrue("name" in row.keys()) - self.assertTrue("id" in row.keys()) + self.assertEqual(type(row), dict) + self.assertIn("name", row.keys()) + self.assertIn("id", row.keys()) async def run_batch(self, batch_size): row_count = 0 @@ -44,8 +50,8 @@ def test_batch(self): _iterations = math.ceil(row_count / batch_size) - self.assertTrue(_row_count == row_count) - self.assertTrue(iterations == _iterations) + self.assertEqual(_row_count, row_count) + self.assertEqual(iterations, _iterations) class TestBatchObjects(DBTestCase): @@ -53,10 +59,10 @@ def _check_results(self, batch): """ Make sure the data is returned in the correct format. 
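For reference, the batching pattern these tests drive looks like the following sketch in application code (an async context is assumed; for select queries each chunk is a list of dicts):

async def process_managers() -> None:
    # Stream rows in chunks of 100 rather than loading the whole table.
    async with await Manager.select().batch(batch_size=100) as batch:
        async for rows in batch:
            for row in rows:
                print(row["name"])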
""" - self.assertTrue(type(batch) == list) + self.assertEqual(type(batch), list) if len(batch) > 0: row = batch[0] - self.assertTrue(isinstance(row, Manager)) + self.assertIsInstance(row, Manager) async def run_batch(self, batch_size): row_count = 0 @@ -85,5 +91,79 @@ def test_batch(self): _iterations = math.ceil(row_count / batch_size) - self.assertTrue(_row_count == row_count) - self.assertTrue(iterations == _iterations) + self.assertEqual(_row_count, row_count) + self.assertEqual(iterations, _iterations) + + +class TestBatchRaw(DBTestCase): + def _check_results(self, batch): + """ + Make sure the data is returned in the correct format. + """ + self.assertEqual(type(batch), list) + if len(batch) > 0: + row = batch[0] + self.assertIsInstance(row, Manager) + + async def run_batch(self, batch_size): + row_count = 0 + iterations = 0 + + async with await Manager.raw("SELECT * FROM manager").batch( + batch_size=batch_size + ) as batch: + async for _batch in batch: + self._check_results(_batch) + _row_count = len(_batch) + row_count += _row_count + iterations += 1 + + return row_count, iterations + + async def test_batch(self): + row_count = 1000 + self.insert_many_rows(row_count) + + batch_size = 10 + + _row_count, iterations = asyncio.run( + self.run_batch(batch_size=batch_size), debug=True + ) + + _iterations = math.ceil(row_count / batch_size) + + self.assertEqual(_row_count, row_count) + self.assertEqual(iterations, _iterations) + + +@engines_only("postgres", "cockroach") +class TestBatchNodeArg(TestCase): + def test_batch_extra_node(self): + """ + Make sure the batch methods can accept a node argument. + """ + + # Get the test database credentials: + test_engine = engine_finder() + assert isinstance(test_engine, PostgresEngine) + + EXTRA_NODE = AsyncMock(spec=PostgresEngine(config=test_engine.config)) + + DB = PostgresEngine( + config=test_engine.config, + extra_nodes={"read_1": EXTRA_NODE}, + ) + + class Manager(Table, db=DB): + name = Varchar() + + # Testing `select` + response = run_sync(Manager.select().batch(node="read_1")) + self.assertIsInstance(response, AsyncBatch) + self.assertTrue(EXTRA_NODE.get_new_connection.called) + EXTRA_NODE.reset_mock() + + # Testing `objects` + response = run_sync(Manager.objects().batch(node="read_1")) + self.assertIsInstance(response, AsyncBatch) + self.assertTrue(EXTRA_NODE.get_new_connection.called) diff --git a/tests/table/test_callback.py b/tests/table/test_callback.py new file mode 100644 index 000000000..46ea60f97 --- /dev/null +++ b/tests/table/test_callback.py @@ -0,0 +1,228 @@ +from unittest.mock import Mock + +from tests.base import AsyncMock, DBTestCase +from tests.example_apps.music.tables import Band + + +def identity(x): + """Returns the input. Used as the side effect for mock callbacks.""" + return x + + +def get_name(results): + return results["name"] + + +async def uppercase(name): + """Async to ensure coroutines are called correctly.""" + return name.upper() + + +def limit(name): + return name[:6] + + +class TestNoCallbackSelect(DBTestCase): + def test_no_callback(self): + """ + Just check we don't get any "NoneType is not callable" kind of errors + when we run a select query without setting any callbacks. + """ + self.insert_row() + Band.select(Band.name).run_sync() + + +class TestNoCallbackObjects(DBTestCase): + def test_no_callback(self): + """ + Just check we don't get any "NoneType is not callable" kind of errors + when we run an objects query without setting any callbacks. 
+ """ + self.insert_row() + Band.objects().run_sync() + + +class TestCallbackSuccessesSelect(DBTestCase): + def test_callback_sync(self): + self.insert_row() + + callback_handler = Mock(return_value="it worked") + result = Band.select(Band.name).callback(callback_handler).run_sync() + callback_handler.assert_called_once_with([{"name": "Pythonistas"}]) + self.assertEqual(result, "it worked") + + def test_callback_async(self): + self.insert_row() + + callback_handler = AsyncMock(return_value="it worked") + result = Band.select(Band.name).callback(callback_handler).run_sync() + callback_handler.assert_called_once_with([{"name": "Pythonistas"}]) + self.assertEqual(result, "it worked") + + +class TestCallbackSuccessesObjects(DBTestCase): + def test_callback_sync(self): + self.insert_row() + + callback_handler = Mock(return_value="it worked") + result = Band.objects().callback(callback_handler).run_sync() + callback_handler.assert_called_once() + + args = callback_handler.call_args[0][0] + self.assertIsInstance(args, list) + self.assertIsInstance(args[0], Band) + self.assertEqual(args[0].name, "Pythonistas") + self.assertEqual(result, "it worked") + + def test_callback_async(self): + self.insert_row() + + callback_handler = AsyncMock(return_value="it worked") + result = Band.objects().callback(callback_handler).run_sync() + callback_handler.assert_called_once() + + args = callback_handler.call_args[0][0] + self.assertIsInstance(args, list) + self.assertIsInstance(args[0], Band) + self.assertEqual(args[0].name, "Pythonistas") + self.assertEqual(result, "it worked") + + +class TestMultipleCallbacksSelect(DBTestCase): + def test_all_sync(self): + self.insert_row() + + handlers = [ + Mock(side_effect=identity), + Mock(side_effect=identity), + Mock(side_effect=identity), + ] + Band.select(Band.name).callback(handlers).run_sync() + + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + def test_all_sync_chained(self): + self.insert_row() + + handlers = [ + Mock(side_effect=identity), + Mock(side_effect=identity), + Mock(side_effect=identity), + ] + + ( + Band.select(Band.name) + .callback(handlers[0]) + .callback(handlers[1]) + .callback(handlers[2]) + .run_sync() + ) + + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + def test_all_async(self): + self.insert_row() + + handlers = [ + AsyncMock(side_effect=identity), + AsyncMock(side_effect=identity), + AsyncMock(side_effect=identity), + ] + Band.select(Band.name).callback(handlers).run_sync() + + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + def test_all_async_chained(self): + self.insert_row() + + handlers = [ + AsyncMock(side_effect=identity), + AsyncMock(side_effect=identity), + AsyncMock(side_effect=identity), + ] + ( + Band.select(Band.name) + .callback(handlers[0]) + .callback(handlers[1]) + .callback(handlers[2]) + .run_sync() + ) + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + def test_mixed(self): + self.insert_row() + + handlers = [ + Mock(side_effect=identity), + AsyncMock(side_effect=identity), + Mock(side_effect=identity), + ] + Band.select(Band.name).callback(handlers).run_sync() + + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + def test_mixed_chained(self): + self.insert_row() + + handlers = [ + Mock(side_effect=identity), + AsyncMock(side_effect=identity), + Mock(side_effect=identity), + ] + + ( + Band.select(Band.name) + 
.callback(handlers[0]) + .callback(handlers[1]) + .callback(handlers[2]) + .run_sync() + ) + + for handler in handlers: + handler.assert_called_once_with([{"name": "Pythonistas"}]) + + +class TestCallbackTransformDataSelect(DBTestCase): + def test_transform(self): + self.insert_row() + + result = ( + Band.select(Band.name) + .first() + .callback([get_name, uppercase, limit]) + .run_sync() + ) + + self.assertEqual(result, "PYTHON") + + def test_transform_chain(self): + self.insert_row() + + result = ( + Band.select(Band.name) + .first() + .callback(get_name) + .callback(uppercase) + .callback(limit) + .run_sync() + ) + + self.assertEqual(result, "PYTHON") + + def test_transform_mixed(self): + self.insert_row() + + result = ( + Band.select(Band.name) + .first() + .callback([get_name, uppercase]) + .callback(limit) + .run_sync() + ) + + self.assertEqual(result, "PYTHON") diff --git a/tests/table/test_constructor.py b/tests/table/test_constructor.py new file mode 100644 index 000000000..78263e423 --- /dev/null +++ b/tests/table/test_constructor.py @@ -0,0 +1,28 @@ +from unittest import TestCase + +from tests.example_apps.music.tables import Band + + +class TestConstructor(TestCase): + def test_data_parameter(self): + """ + Make sure the _data parameter works. + """ + band = Band({Band.name: "Pythonistas"}) + self.assertEqual(band.name, "Pythonistas") + + def test_kwargs(self): + """ + Make sure kwargs work. + """ + band = Band(name="Pythonistas") + self.assertEqual(band.name, "Pythonistas") + + def test_mix(self): + """ + Make sure the _data parameter and kwargs work together (it's unlikely + people will do this, but just in case). + """ + band = Band({Band.name: "Pythonistas"}, popularity=1000) + self.assertEqual(band.name, "Pythonistas") + self.assertEqual(band.popularity, 1000) diff --git a/tests/table/test_count.py b/tests/table/test_count.py index 037fab5bc..7da90bea0 100644 --- a/tests/table/test_count.py +++ b/tests/table/test_count.py @@ -1,11 +1,77 @@ -from ..base import DBTestCase -from ..example_app.tables import Band +from unittest import TestCase +from piccolo.columns import Integer, Varchar +from piccolo.table import Table -class TestCount(DBTestCase): - def test_exists(self): - self.insert_rows() - response = Band.count().where(Band.name == "Pythonistas").run_sync() +class Band(Table): + name = Varchar() + popularity = Integer() - self.assertTrue(response == 1) + +class TestCount(TestCase): + def setUp(self) -> None: + Band.create_table().run_sync() + + def tearDown(self) -> None: + Band.alter().drop_table().run_sync() + + def test_count(self): + Band.insert( + Band(name="Pythonistas", popularity=10), + Band(name="Rustaceans", popularity=10), + Band(name="C-Sharps", popularity=5), + ).run_sync() + + response = Band.count().run_sync() + + self.assertEqual(response, 3) + + def test_count_where(self): + Band.insert( + Band(name="Pythonistas", popularity=10), + Band(name="Rustaceans", popularity=10), + Band(name="C-Sharps", popularity=5), + ).run_sync() + + response = Band.count().where(Band.popularity == 10).run_sync() + + self.assertEqual(response, 2) + + def test_count_distinct(self): + Band.insert( + Band(name="Pythonistas", popularity=10), + Band(name="Rustaceans", popularity=10), + Band(name="C-Sharps", popularity=5), + Band(name="Fortranists", popularity=2), + ).run_sync() + + response = Band.count(distinct=[Band.popularity]).run_sync() + + self.assertEqual(response, 3) + + # Test the method also works + response = Band.count().distinct([Band.popularity]).run_sync() + 
self.assertEqual(response, 3) + + def test_count_distinct_multiple(self): + Band.insert( + Band(name="Pythonistas", popularity=10), + Band(name="Pythonistas", popularity=10), + Band(name="Rustaceans", popularity=10), + Band(name="C-Sharps", popularity=5), + Band(name="Fortranists", popularity=2), + ).run_sync() + + response = Band.count(distinct=[Band.name, Band.popularity]).run_sync() + + self.assertEqual(response, 4) + + def test_value_error(self): + """ + Make sure specifying `column` and `distinct` raises an error. + """ + with self.assertRaises(ValueError): + Band.count( + column=Band.name, distinct=[Band.name, Band.popularity] + ).run_sync() diff --git a/tests/table/test_create.py b/tests/table/test_create.py index b8cd90034..7dd936e59 100644 --- a/tests/table/test_create.py +++ b/tests/table/test_create.py @@ -1,9 +1,10 @@ from unittest import TestCase from piccolo.columns import Varchar +from piccolo.schema import SchemaManager from piccolo.table import Table - -from ..example_app.tables import Manager +from tests.base import engines_only +from tests.example_apps.music.tables import Manager class TestCreate(TestCase): @@ -28,7 +29,7 @@ def tearDown(self): def test_create_table_with_indexes(self): index_names = BandMember.indexes().run_sync() index_name = BandMember._get_index_name(["name"]) - self.assertTrue(index_name in index_names) + self.assertIn(index_name, index_names) def test_create_if_not_exists_with_indexes(self): """ @@ -42,10 +43,55 @@ def test_create_if_not_exists_with_indexes(self): query.run_sync() self.assertTrue( - query.querystrings[0] - .__str__() - .startswith("CREATE TABLE IF NOT EXISTS"), - query.querystrings[1] - .__str__() - .startswith("CREATE INDEX IF NOT EXISTS"), + query.ddl[0].__str__().startswith("CREATE TABLE IF NOT EXISTS") + ) + self.assertTrue( + query.ddl[1].__str__().startswith("CREATE INDEX IF NOT EXISTS") + ) + + +@engines_only("postgres", "cockroach") +class TestCreateWithSchema(TestCase): + + manager = SchemaManager() + + class Band(Table, tablename="band", schema="schema_1"): + name = Varchar(length=50, index=True) + + def tearDown(self) -> None: + self.manager.drop_schema( + schema_name="schema_1", cascade=True + ).run_sync() + + def test_table_created(self): + """ + Make sure that tables can be created in specific schemas. + """ + Band = self.Band + Band.create_table().run_sync() + + self.assertIn( + "band", + self.manager.list_tables(schema_name="schema_1").run_sync(), + ) + + +@engines_only("postgres", "cockroach") +class TestCreateWithPublicSchema(TestCase): + class Band(Table, tablename="band", schema="public"): + name = Varchar(length=50, index=True) + + def tearDown(self) -> None: + self.Band.alter().drop_table(if_exists=True).run_sync() + + def test_table_created(self): + """ + Make sure that if the schema is explicitly set to 'public' rather than + ``None``, we don't try creating the schema, which would cause the query + to fail. 
+ """ + Band = self.Band + Band.create_table().run_sync() + + self.assertIn( + "band", + SchemaManager().list_tables(schema_name="public").run_sync(), ) diff --git a/tests/table/test_create_db_tables.py b/tests/table/test_create_db_tables.py new file mode 100644 index 000000000..fdcf2a5d5 --- /dev/null +++ b/tests/table/test_create_db_tables.py @@ -0,0 +1,29 @@ +from unittest import TestCase + +from piccolo.table import ( + create_db_tables_sync, + create_tables, + drop_db_tables_sync, +) +from tests.example_apps.music.tables import Band, Manager + + +class TestCreateDBTables(TestCase): + def tearDown(self) -> None: + drop_db_tables_sync(Manager, Band) + + def test_create_db_tables(self): + """ + Make sure the tables are created in the database. + """ + create_db_tables_sync(Manager, Band, if_not_exists=False) + self.assertTrue(Manager.table_exists().run_sync()) + self.assertTrue(Band.table_exists().run_sync()) + + def test_create_tables(self): + """ + This is a deprecated function, which just acts as a proxy. + """ + create_tables(Manager, Band, if_not_exists=False) + self.assertTrue(Manager.table_exists().run_sync()) + self.assertTrue(Band.table_exists().run_sync()) diff --git a/tests/table/test_create_table_class.py b/tests/table/test_create_table_class.py index 1c891e443..ef18f6b17 100644 --- a/tests/table/test_create_table_class.py +++ b/tests/table/test_create_table_class.py @@ -1,7 +1,8 @@ from unittest import TestCase +from unittest.mock import patch from piccolo.columns import Varchar -from piccolo.table import create_table_class +from piccolo.table import TABLENAME_WARNING, create_table_class class TestCreateTableClass(TestCase): @@ -21,22 +22,28 @@ def test_create_table_class(self): _Table = create_table_class( class_name="MyTable", class_members={"name": column} ) - self.assertTrue(column in _Table._meta.columns) + self.assertIn(column, _Table._meta.columns) def test_protected_tablenames(self): """ Make sure that the logic around protected tablenames still works as expected. 
""" - with self.assertRaises(ValueError): + expected_warning = TABLENAME_WARNING.format(tablename="user") + + with patch("piccolo.table.warnings") as warnings: create_table_class(class_name="User") + warnings.warn.assert_called_once_with(expected_warning) - with self.assertRaises(ValueError): + with patch("piccolo.table.warnings") as warnings: create_table_class( class_name="MyUser", class_kwargs={"tablename": "user"} ) + warnings.warn.assert_called_once_with(expected_warning) - # This shouldn't raise an error: - create_table_class( - class_name="User", class_kwargs={"tablename": "my_user"} - ) + # This shouldn't output a warning: + with patch("piccolo.table.warnings") as warnings: + create_table_class( + class_name="User", class_kwargs={"tablename": "my_user"} + ) + warnings.warn.assert_not_called() diff --git a/tests/table/test_delete.py b/tests/table/test_delete.py index 2c5c6d35f..218acd458 100644 --- a/tests/table/test_delete.py +++ b/tests/table/test_delete.py @@ -1,7 +1,8 @@ -from piccolo.query.methods.delete import DeletionError +import pytest -from ..base import DBTestCase -from ..example_app.tables import Band +from piccolo.query.methods.delete import DeletionError +from tests.base import DBTestCase, engine_version_lt, is_running_sqlite +from tests.example_apps.music.tables import Band class TestDelete(DBTestCase): @@ -11,10 +12,30 @@ def test_delete(self): Band.delete().where(Band.name == "CSharps").run_sync() response = Band.count().where(Band.name == "CSharps").run_sync() - print(f"response = {response}") self.assertEqual(response, 0) + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite version not supported", + ) + def test_delete_returning(self): + """ + Make sure delete works with the `returning` clause. + """ + + self.insert_rows() + + response = ( + Band.delete() + .where(Band.name == "CSharps") + .returning(Band.name) + .run_sync() + ) + + self.assertEqual(len(response), 1) + self.assertEqual(response, [{"name": "CSharps"}]) + def test_validation(self): """ Make sure you can't delete all the data without forcing it. @@ -23,3 +44,18 @@ def test_validation(self): Band.delete().run_sync() Band.delete(force=True).run_sync() + + def test_delete_with_joins(self): + """ + Make sure delete works if the `where` clause specifies joins. + """ + + self.insert_rows() + + Band.delete().where(Band.manager._.name == "Guido").run_sync() + + response = ( + Band.count().where(Band.manager._.name == "Guido").run_sync() + ) + + self.assertEqual(response, 0) diff --git a/tests/table/test_drop_db_tables.py b/tests/table/test_drop_db_tables.py new file mode 100644 index 000000000..bfbf85890 --- /dev/null +++ b/tests/table/test_drop_db_tables.py @@ -0,0 +1,37 @@ +from unittest import TestCase + +from piccolo.table import ( + create_db_tables_sync, + drop_db_tables_sync, + drop_tables, +) +from tests.example_apps.music.tables import Band, Manager + + +class TestDropTables(TestCase): + def setUp(self): + create_db_tables_sync(Band, Manager) + + def test_drop_db_tables(self): + """ + Make sure the tables are dropped. + """ + self.assertTrue(Manager.table_exists().run_sync()) + self.assertTrue(Band.table_exists().run_sync()) + + drop_db_tables_sync(Manager, Band) + + self.assertFalse(Manager.table_exists().run_sync()) + self.assertFalse(Band.table_exists().run_sync()) + + def test_drop_tables(self): + """ + This is a deprecated function, which just acts as a proxy. 
+ """ + self.assertTrue(Manager.table_exists().run_sync()) + self.assertTrue(Band.table_exists().run_sync()) + + drop_tables(Manager, Band) + + self.assertFalse(Manager.table_exists().run_sync()) + self.assertFalse(Band.table_exists().run_sync()) diff --git a/tests/table/test_exists.py b/tests/table/test_exists.py index b261cdb0e..8d8f07cc0 100644 --- a/tests/table/test_exists.py +++ b/tests/table/test_exists.py @@ -1,5 +1,5 @@ -from ..base import DBTestCase -from ..example_app.tables import Band +from tests.base import DBTestCase +from tests.example_apps.music.tables import Band class TestExists(DBTestCase): @@ -8,4 +8,4 @@ def test_exists(self): response = Band.exists().where(Band.name == "Pythonistas").run_sync() - self.assertTrue(response is True) + self.assertTrue(response) diff --git a/tests/table/test_from_dict.py b/tests/table/test_from_dict.py new file mode 100644 index 000000000..3c931f502 --- /dev/null +++ b/tests/table/test_from_dict.py @@ -0,0 +1,28 @@ +from unittest import TestCase + +from piccolo.columns import Varchar +from piccolo.table import Table + + +class BandMember(Table): + name = Varchar(length=50, index=True) + + +class TestCreateFromDict(TestCase): + def setUp(self): + BandMember.create_table().run_sync() + + def tearDown(self): + BandMember.alter().drop_table().run_sync() + + def test_create_table_from_dict(self): + BandMember.from_dict({"name": "John"}).save().run_sync() + self.assertEqual( + BandMember.select(BandMember.name).run_sync(), [{"name": "John"}] + ) + BandMember.from_dict({"name": "Town"}).save().run_sync() + self.assertEqual(BandMember.count().run_sync(), 2) + self.assertEqual( + BandMember.select(BandMember.name).run_sync(), + [{"name": "John"}, {"name": "Town"}], + ) diff --git a/tests/table/test_indexes.py b/tests/table/test_indexes.py index a4e40f652..a3d44bcf1 100644 --- a/tests/table/test_indexes.py +++ b/tests/table/test_indexes.py @@ -1,10 +1,18 @@ from unittest import TestCase -from ..base import DBTestCase -from ..example_app.tables import Manager +from piccolo.columns.base import Column +from piccolo.columns.column_types import Integer +from piccolo.table import Table +from tests.example_apps.music.tables import Manager -class TestIndexes(DBTestCase): +class TestIndexes(TestCase): + def setUp(self): + Manager.create_table().run_sync() + + def tearDown(self): + Manager.alter().drop_table().run_sync() + def test_create_index(self): """ Test single column and multi column indexes. @@ -20,11 +28,39 @@ def test_create_index(self): ) index_names = Manager.indexes().run_sync() - self.assertTrue(index_name in index_names) + self.assertIn(index_name, index_names) Manager.drop_index(columns).run_sync() index_names = Manager.indexes().run_sync() - self.assertTrue(index_name not in index_names) + self.assertNotIn(index_name, index_names) + + +class Concert(Table): + order = Integer() + + +class TestProblematicColumnName(TestCase): + def setUp(self): + Concert.create_table().run_sync() + + def tearDown(self): + Concert.alter().drop_table().run_sync() + + def test_problematic_name(self) -> None: + """ + Make sure we can add an index to a column with a problematic name + (which clashes with a SQL keyword). 
+ """ + columns: list[Column] = [Concert.order] + Concert.create_index(columns=columns).run_sync() + index_name = Concert._get_index_name([i._meta.name for i in columns]) + + index_names = Concert.indexes().run_sync() + self.assertIn(index_name, index_names) + + Concert.drop_index(columns).run_sync() + index_names = Concert.indexes().run_sync() + self.assertNotIn(index_name, index_names) class TestIndexName(TestCase): diff --git a/tests/table/test_inheritance.py b/tests/table/test_inheritance.py index 8030bb4b7..a7ab2c90e 100644 --- a/tests/table/test_inheritance.py +++ b/tests/table/test_inheritance.py @@ -61,6 +61,7 @@ def test_inheritance(self): ).save().run_sync() response = Manager.select().first().run_sync() + assert response is not None self.assertEqual(response["started_on"], started_on) self.assertEqual(response["name"], name) self.assertEqual(response["favourite"], favourite) @@ -98,6 +99,7 @@ def test_inheritance(self): _Table(name=name, started_on=started_on).save().run_sync() response = _Table.select().first().run_sync() + assert response is not None self.assertEqual(response["started_on"], started_on) self.assertEqual(response["name"], name) diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 45eb77bd3..19c1b0acb 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -1,5 +1,21 @@ -from ..base import DBTestCase -from ..example_app.tables import Band, Manager +import sqlite3 +from unittest import TestCase + +import pytest + +from piccolo.columns import Integer, Serial, Varchar +from piccolo.query.methods.insert import OnConflictAction +from piccolo.table import Table +from piccolo.utils.lazy_loader import LazyLoader +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_only, + is_running_sqlite, +) +from tests.example_apps.music.tables import Band, Manager + +asyncpg = LazyLoader("asyncpg", globals(), "asyncpg") class TestInsert(DBTestCase): @@ -11,7 +27,7 @@ def test_insert(self): response = Band.select(Band.name).run_sync() names = [i["name"] for i in response] - self.assertTrue("Rustaceans" in names) + self.assertIn("Rustaceans", names) def test_add(self): self.insert_rows() @@ -21,7 +37,7 @@ def test_add(self): response = Band.select(Band.name).run_sync() names = [i["name"] for i in response] - self.assertTrue("Rustaceans" in names) + self.assertIn("Rustaceans", names) def test_incompatible_type(self): """ @@ -41,4 +57,444 @@ def test_insert_curly_braces(self): response = Band.select(Band.name).run_sync() names = [i["name"] for i in response] - self.assertTrue("{}" in names) + self.assertIn("{}", names) + + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite version not supported", + ) + def test_insert_returning(self): + """ + Make sure update works with the `returning` clause. + """ + response = ( + Manager.insert(Manager(name="Maz")) + .returning(Manager.name) + .run_sync() + ) + + self.assertListEqual(response, [{"name": "Maz"}]) + + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite version not supported", + ) + def test_insert_returning_alias(self): + """ + Make sure update works with the `returning` clause. 
+ """ + response = ( + Manager.insert(Manager(name="Maz")) + .returning(Manager.name.as_alias("manager_name")) + .run_sync() + ) + + self.assertListEqual(response, [{"manager_name": "Maz"}]) + + +@pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.24), + reason="SQLite version not supported", +) +class TestOnConflict(TestCase): + class Band(Table): + id: Serial + name = Varchar(unique=True) + popularity = Integer() + + def setUp(self) -> None: + Band = self.Band + Band.create_table().run_sync() + self.band = Band({Band.name: "Pythonistas", Band.popularity: 1000}) + self.band.save().run_sync() + + def tearDown(self) -> None: + Band = self.Band + Band.alter().drop_table().run_sync() + + def test_do_update(self): + """ + Make sure that `DO UPDATE` works. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + values=[Band.popularity], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": new_popularity, # changed + } + ], + ) + + def test_do_update_tuple_values(self): + """ + Make sure we can use tuples in ``values``. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + new_name = "Rustaceans" + + Band.insert( + Band( + id=self.band.id, + name=new_name, + popularity=new_popularity, + ) + ).on_conflict( + action="DO UPDATE", + target=Band.id, + values=[ + (Band.name, new_name), + (Band.popularity, new_popularity + 2000), + ], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": new_name, + "popularity": new_popularity + 2000, + } + ], + ) + + def test_do_update_no_target(self): + """ + Make sure that `DO UPDATE` with no `target` raises an exception. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + with self.assertRaises(ValueError) as manager: + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + action="DO UPDATE", + values=[(Band.popularity, new_popularity + 2000)], + ).run_sync() + + self.assertEqual( + manager.exception.__str__(), + "The `target` option must be provided with DO UPDATE.", + ) + + def test_do_update_no_values(self): + """ + Make sure that `DO UPDATE` with no `values` raises an exception. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + with self.assertRaises(ValueError) as manager: + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + ).run_sync() + + self.assertEqual( + manager.exception.__str__(), + "No values specified for `on conflict`", + ) + + @engines_only("postgres", "cockroach") + def test_target_tuple(self): + """ + Make sure that a composite unique constraint can be used as a target. + + We only run it on Postgres and Cockroach because we use ALTER TABLE + to add a constraint, which SQLite doesn't support. 
+ """ + Band = self.Band + + # Add a composite unique constraint: + Band.raw( + "ALTER TABLE band ADD CONSTRAINT id_name_unique UNIQUE (id, name)" + ).run_sync() + + Band.insert( + Band( + id=self.band.id, + name=self.band.name, + popularity=self.band.popularity, + ) + ).on_conflict( + target=(Band.id, Band.name), + action="DO NOTHING", + ).run_sync() + + @engines_only("postgres", "cockroach") + def test_target_string(self): + """ + Make sure we can explicitly specify the name of target constraint using + a string. + + We just test this on Postgres for now, as we have to get the constraint + name from the database. + """ + Band = self.Band + + constraint_name = [ + i["constraint_name"] + for i in Band.raw( + """ + SELECT constraint_name + FROM information_schema.constraint_column_usage + WHERE column_name = 'name' + AND table_name = 'band'; + """ + ).run_sync() + if i["constraint_name"].endswith("_key") + ][0] + + query = Band.insert(Band(name=self.band.name)).on_conflict( + target=constraint_name, + action="DO NOTHING", + ) + self.assertIn(f'ON CONSTRAINT "{constraint_name}"', query.__str__()) + query.run_sync() + + def test_violate_non_target(self): + """ + Make sure that if we specify a target constraint, but violate a + different constraint, then we still get the error. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + with self.assertRaises(Exception) as manager: + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.id, # Target the primary key instead. + action="DO UPDATE", + values=[Band.popularity], + ).run_sync() + + if self.Band._meta.db.engine_type in ("postgres", "cockroach"): + self.assertIsInstance( + manager.exception, asyncpg.exceptions.UniqueViolationError + ) + elif self.Band._meta.db.engine_type == "sqlite": + self.assertIsInstance(manager.exception, sqlite3.IntegrityError) + + def test_where(self): + """ + Make sure we can pass in a `where` argument. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + query = Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + values=[Band.popularity], + where=Band.popularity < self.band.popularity, + ) + + self.assertIn( + f'WHERE "band"."popularity" < {self.band.popularity}', + query.__str__(), + ) + + query.run_sync() + + def test_do_nothing_where(self): + """ + Make sure an error is raised if `where` is used with `DO NOTHING`. + """ + Band = self.Band + + with self.assertRaises(ValueError) as manager: + Band.insert(Band()).on_conflict( + action="DO NOTHING", + where=Band.popularity < self.band.popularity, + ) + + self.assertEqual( + manager.exception.__str__(), + "The `where` option can only be used with DO NOTHING.", + ) + + def test_do_nothing(self): + """ + Make sure that `DO NOTHING` works. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + Band.insert( + Band(name="Pythonistas", popularity=new_popularity) + ).on_conflict(action="DO NOTHING").run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": self.band.popularity, + } + ], + ) + + @engines_only("sqlite") + def test_multiple_do_update(self): + """ + Make sure multiple `ON CONFLICT` clauses work for SQLite. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + # Conflicting with name - should update. 
+ Band.insert(
+ Band(name="Pythonistas", popularity=new_popularity)
+ ).on_conflict(action="DO NOTHING", target=Band.id).on_conflict(
+ action="DO UPDATE", target=Band.name, values=[Band.popularity]
+ ).run_sync()
+
+ self.assertListEqual(
+ Band.select().run_sync(),
+ [
+ {
+ "id": self.band.id,
+ "name": self.band.name,
+ "popularity": new_popularity, # changed
+ }
+ ],
+ )
+
+ @engines_only("sqlite")
+ def test_multiple_do_nothing(self):
+ """
+ Make sure multiple `ON CONFLICT` clauses work for SQLite.
+ """
+ Band = self.Band
+
+ new_popularity = self.band.popularity + 1000
+
+ # Conflicting with ID - should be ignored.
+ Band.insert(
+ Band(
+ id=self.band.id,
+ name="Pythonistas",
+ popularity=new_popularity,
+ )
+ ).on_conflict(action="DO NOTHING", target=Band.id).on_conflict(
+ action="DO UPDATE",
+ target=Band.name,
+ values=[Band.popularity],
+ ).run_sync()
+
+ self.assertListEqual(
+ Band.select().run_sync(),
+ [
+ {
+ "id": self.band.id,
+ "name": self.band.name,
+ "popularity": self.band.popularity,
+ }
+ ],
+ )
+
+ @engines_only("postgres", "cockroach")
+ def test_multiple_error(self):
+ """
+ Postgres and Cockroach don't support multiple `ON CONFLICT` clauses.
+ """
+ with self.assertRaises(NotImplementedError) as manager:
+ Band = self.Band
+
+ Band.insert(Band()).on_conflict(action="DO NOTHING").on_conflict(
+ action="DO UPDATE",
+ ).run_sync()
+
+ assert manager.exception.__str__() == (
+ "Postgres and Cockroach only support a single ON CONFLICT clause."
+ )
+
+ def test_all_columns(self):
+ """
+ We can use ``all_columns`` instead of specifying the ``values``
+ manually.
+ """
+ Band = self.Band
+
+ new_popularity = self.band.popularity + 1000
+ new_name = "Rustaceans"
+
+ # Conflicting with ID - all columns should be updated.
+ q = Band.insert(
+ Band(
+ id=self.band.id,
+ name=new_name,
+ popularity=new_popularity,
+ )
+ ).on_conflict(
+ action="DO UPDATE",
+ target=Band.id,
+ values=Band.all_columns(),
+ )
+ q.run_sync()
+
+ self.assertListEqual(
+ Band.select().run_sync(),
+ [
+ {
+ "id": self.band.id,
+ "name": new_name,
+ "popularity": new_popularity,
+ }
+ ],
+ )
+
+ def test_enum(self):
+ """
+ Either a string literal or an enum can be passed in to determine the
+ action. Make sure that the enum works.
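+
+ i.e. these two calls should be equivalent:
+
+     .on_conflict(action="DO NOTHING")
+     .on_conflict(action=OnConflictAction.do_nothing)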
+ """ + Band = self.Band + + Band.insert( + Band( + id=self.band.id, + name=self.band.name, + popularity=self.band.popularity, + ) + ).on_conflict(action=OnConflictAction.do_nothing).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": self.band.popularity, + } + ], + ) diff --git a/tests/table/test_join.py b/tests/table/test_join.py index e0d59e098..1d3aba2f3 100644 --- a/tests/table/test_join.py +++ b/tests/table/test_join.py @@ -1,6 +1,14 @@ +import decimal from unittest import TestCase -from ..example_app.tables import Band, Concert, Manager, Venue +from tests.base import engine_is +from tests.example_apps.music.tables import ( + Band, + Concert, + Manager, + Ticket, + Venue, +) TABLES = [Manager, Band, Venue, Concert] @@ -15,25 +23,18 @@ def test_create_join(self): class TestJoin(TestCase): - """ - Test instantiating Table instances - """ - - tables = [Manager, Band, Venue, Concert] + tables = [Manager, Band, Venue, Concert, Ticket] def setUp(self): for table in self.tables: table.create_table().run_sync() - def tearDown(self): - for table in reversed(self.tables): - table.alter().drop_table().run_sync() - - def test_join(self): manager_1 = Manager(name="Guido") manager_1.save().run_sync() - band_1 = Band(name="Pythonistas", manager=manager_1.id) + band_1 = Band( + name="Pythonistas", manager=manager_1.id, popularity=1000 + ) band_1.save().run_sync() manager_2 = Manager(name="Graydon") @@ -42,14 +43,22 @@ def test_join(self): band_2 = Band(name="Rustaceans", manager=manager_2.id) band_2.save().run_sync() - venue = Venue(name="Grand Central") + venue = Venue(name="Grand Central", capacity=1000) venue.save().run_sync() - save_query = Concert( - band_1=band_1.id, band_2=band_2.id, venue=venue.id - ).save() - save_query.run_sync() + concert = Concert(band_1=band_1.id, band_2=band_2.id, venue=venue.id) + concert.save().run_sync() + + ticket = Ticket(concert=concert, price=decimal.Decimal(50.0)) + ticket.save().run_sync() + def tearDown(self): + for table in reversed(self.tables): + table.alter().drop_table().run_sync() + + ########################################################################### + + def test_join(self): select_query = Concert.select( Concert.band_1.name, Concert.band_2.name, @@ -57,19 +66,423 @@ def test_join(self): Concert.band_1.manager, ) response = select_query.run_sync() - self.assertEqual( + + if engine_is("cockroach"): + self.assertEqual( + response, + [ + { + "band_1.name": "Pythonistas", + "band_2.name": "Rustaceans", + "venue.name": "Grand Central", + "band_1.manager": response[0]["band_1.manager"], + } + ], + ) + else: + self.assertEqual( + response, + [ + { + "band_1.name": "Pythonistas", + "band_2.name": "Rustaceans", + "venue.name": "Grand Central", + "band_1.manager": 1, + } + ], + ) + + # Now make sure that even deeper joins work: + select_query = Concert.select(Concert.band_1._.manager._.name) + response = select_query.run_sync() + self.assertEqual(response, [{"band_1.manager.name": "Guido"}]) + + def test_underscore_syntax(self): + """ + Make sure that queries work with the ``._.`` syntax for joins. 
+ """ + response = Concert.select( + Concert.band_1._.name, + Concert.band_1._.manager._.name, + Concert.band_2._.name, + Concert.band_2._.manager._.name, + ).run_sync() + + self.assertListEqual( response, [ { "band_1.name": "Pythonistas", + "band_1.manager.name": "Guido", "band_2.name": "Rustaceans", - "venue.name": "Grand Central", - "band_1.manager": 1, + "band_2.manager.name": "Graydon", } ], ) - # Now make sure that even deeper joins work: - select_query = Concert.select(Concert.band_1.manager.name) - response = select_query.run_sync() - self.assertEqual(response, [{"band_1.manager.name": "Guido"}]) + def test_select_all_columns(self): + """ + Make sure you can retrieve all columns from a related table, without + explicitly specifying them. + """ + result = ( + Band.select(Band.name, *Band.manager.all_columns()) + .first() + .run_sync() + ) + assert result is not None + + if engine_is("cockroach"): + self.assertDictEqual( + result, + { + "name": "Pythonistas", + "manager.id": result["manager.id"], + "manager.name": "Guido", + }, + ) + else: + self.assertDictEqual( + result, + { + "name": "Pythonistas", + "manager.id": 1, + "manager.name": "Guido", + }, + ) + + def test_select_all_columns_deep(self): + """ + Make sure that ``all_columns`` can be used several layers deep. + """ + result = ( + Concert.select( + *Concert.venue.all_columns(), + *Concert.band_1._.manager.all_columns(), + *Concert.band_2._.manager.all_columns(), + ) + .first() + .run_sync() + ) + assert result is not None + + if engine_is("cockroach"): + self.assertDictEqual( + result, + { + "venue.id": result["venue.id"], + "venue.name": "Grand Central", + "venue.capacity": 1000, + "band_1.manager.id": result["band_1.manager.id"], + "band_1.manager.name": "Guido", + "band_2.manager.id": result["band_2.manager.id"], + "band_2.manager.name": "Graydon", + }, + ) + else: + self.assertDictEqual( + result, + { + "venue.id": 1, + "venue.name": "Grand Central", + "venue.capacity": 1000, + "band_1.manager.id": 1, + "band_1.manager.name": "Guido", + "band_2.manager.id": 2, + "band_2.manager.name": "Graydon", + }, + ) + + def test_proxy_columns(self): + """ + Make sure that ``proxy_columns`` are set correctly. + + There used to be a bug which meant queries got slower over time: + + https://github.com/piccolo-orm/piccolo/issues/691 + + """ + # We call it multiple times to make sure it doesn't change with time. + for _ in range(2): + self.assertEqual( + len(Concert.band_1._.manager._foreign_key_meta.proxy_columns), + 2, + ) + self.assertEqual( + len(Concert.band_1._foreign_key_meta.proxy_columns), 4 + ) + + def test_select_all_columns_root(self): + """ + Make sure that using ``all_columns`` at the root doesn't interfere + with using it for referenced tables. + """ + result = ( + Band.select( + *Band.all_columns(), + *Band.manager.all_columns(), + ) + .first() + .run_sync() + ) + assert result is not None + + if engine_is("cockroach"): + self.assertDictEqual( + result, + { + "id": result["id"], + "name": "Pythonistas", + "manager": result["manager"], + "popularity": 1000, + "manager.id": result["manager.id"], + "manager.name": "Guido", + }, + ) + else: + self.assertDictEqual( + result, + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "popularity": 1000, + "manager.id": 1, + "manager.name": "Guido", + }, + ) + + def test_select_all_columns_root_nested(self): + """ + Make sure that using ``all_columns`` at the root doesn't interfere + with using it for referenced tables. 
+ """ + result = ( + Band.select(*Band.all_columns(), *Band.manager.all_columns()) + .output(nested=True) + .first() + .run_sync() + ) + assert result is not None + + if engine_is("cockroach"): + self.assertDictEqual( + result, + { + "id": result["id"], + "name": "Pythonistas", + "manager": { + "id": result["manager"]["id"], + "name": "Guido", + }, + "popularity": 1000, + }, + ) + else: + self.assertDictEqual( + result, + { + "id": 1, + "name": "Pythonistas", + "manager": {"id": 1, "name": "Guido"}, + "popularity": 1000, + }, + ) + + def test_select_all_columns_exclude(self): + """ + Make sure we can get all columns, except the ones we specify. + """ + result = ( + Band.select( + *Band.all_columns(exclude=[Band.id]), + *Band.manager.all_columns(exclude=[Band.manager.id]), + ) + .output(nested=True) + .first() + .run_sync() + ) + assert result is not None + + result_str_args = ( + Band.select( + *Band.all_columns(exclude=["id"]), + *Band.manager.all_columns(exclude=["id"]), + ) + .output(nested=True) + .first() + .run_sync() + ) + assert result_str_args is not None + + for data in (result, result_str_args): + self.assertDictEqual( + data, + { + "name": "Pythonistas", + "manager": {"name": "Guido"}, + "popularity": 1000, + }, + ) + + ########################################################################### + + def test_objects_nested(self): + """ + Make sure the prefetch argument works correctly for objects. + """ + band = Band.objects(Band.manager).first().run_sync() + assert band is not None + self.assertIsInstance(band.manager, Manager) + + def test_objects__all_related__root(self): + """ + Make sure that ``all_related`` works correctly when called from the + root table of the query. + """ + concert = Concert.objects(Concert.all_related()).first().run_sync() + assert concert is not None + self.assertIsInstance(concert.band_1, Band) + self.assertIsInstance(concert.band_2, Band) + self.assertIsInstance(concert.venue, Venue) + + def test_objects_nested_deep(self): + """ + Make sure that ``prefetch`` works correctly with deeply nested tables. + """ + ticket = ( + Ticket.objects( + Ticket.concert, + Ticket.concert._.band_1, + Ticket.concert._.band_2, + Ticket.concert._.venue, + Ticket.concert._.band_1._.manager, + Ticket.concert._.band_2._.manager, + ) + .first() + .run_sync() + ) + assert ticket is not None + + self.assertIsInstance(ticket.concert, Concert) + self.assertIsInstance(ticket.concert.band_1, Band) + self.assertIsInstance(ticket.concert.band_2, Band) + self.assertIsInstance(ticket.concert.venue, Venue) + self.assertIsInstance(ticket.concert.band_1.manager, Manager) + self.assertIsInstance(ticket.concert.band_2.manager, Manager) + + def test_objects__all_related__deep(self): + """ + Make sure that ``all_related`` works correctly when called on a deeply + nested table. + """ + ticket = ( + Ticket.objects( + Ticket.all_related(), + Ticket.concert.all_related(), + Ticket.concert._.band_1.all_related(), + Ticket.concert._.band_2.all_related(), + ) + .first() + .run_sync() + ) + assert ticket is not None + + self.assertIsInstance(ticket.concert, Concert) + self.assertIsInstance(ticket.concert.band_1, Band) + self.assertIsInstance(ticket.concert.band_2, Band) + self.assertIsInstance(ticket.concert.venue, Venue) + self.assertIsInstance(ticket.concert.band_1.manager, Manager) + self.assertIsInstance(ticket.concert.band_2.manager, Manager) + + def test_objects_prefetch_clause(self): + """ + Make sure that ``prefetch`` clause works correctly. 
+ """ + ticket = ( + Ticket.objects() + .prefetch( + Ticket.all_related(), + Ticket.concert.all_related(), + Ticket.concert._.band_1.all_related(), + Ticket.concert._.band_2.all_related(), + ) + .first() + .run_sync() + ) + assert ticket is not None + + self.assertIsInstance(ticket.concert, Concert) + self.assertIsInstance(ticket.concert.band_1, Band) + self.assertIsInstance(ticket.concert.band_2, Band) + self.assertIsInstance(ticket.concert.venue, Venue) + self.assertIsInstance(ticket.concert.band_1.manager, Manager) + self.assertIsInstance(ticket.concert.band_2.manager, Manager) + + def test_objects_prefetch_intermediate(self): + """ + Make sure when using ``prefetch`` on a deeply nested table, all of the + intermediate objects are also retrieved properly. + """ + ticket = ( + Ticket.objects() + .prefetch( + Ticket.concert._.band_1._.manager, + ) + .first() + .run_sync() + ) + assert ticket is not None + + self.assertIsInstance(ticket.price, decimal.Decimal) + self.assertIsInstance(ticket.concert, Concert) + + self.assertIsInstance(ticket.concert.id, int) + self.assertIsInstance(ticket.concert.band_1, Band) + self.assertIsInstance(ticket.concert.band_2, int) + self.assertIsInstance(ticket.concert.venue, int) + + self.assertIsInstance(ticket.concert.band_1.id, int) + self.assertIsInstance(ticket.concert.band_1.name, str) + self.assertIsInstance(ticket.concert.band_1.manager, Manager) + + self.assertIsInstance(ticket.concert.band_1.manager.id, int) + self.assertIsInstance(ticket.concert.band_1.manager.name, str) + + def test_objects_prefetch_multiple_intermediate(self): + """ + Make sure that if we're fetching multiple deeply nested tables, the + intermediate tables are still created correctly. + """ + ticket = ( + Ticket.objects() + .prefetch( + Ticket.concert._.band_1._.manager, + Ticket.concert._.band_2._.manager, + ) + .first() + .run_sync() + ) + assert ticket is not None + + self.assertIsInstance(ticket.price, decimal.Decimal) + self.assertIsInstance(ticket.concert, Concert) + + self.assertIsInstance(ticket.concert.id, int) + self.assertIsInstance(ticket.concert.band_1, Band) + self.assertIsInstance(ticket.concert.band_2, Band) + self.assertIsInstance(ticket.concert.venue, int) + + self.assertIsInstance(ticket.concert.band_1.id, int) + self.assertIsInstance(ticket.concert.band_1.name, str) + self.assertIsInstance(ticket.concert.band_1.manager, Manager) + + self.assertIsInstance(ticket.concert.band_1.manager.id, int) + self.assertIsInstance(ticket.concert.band_1.manager.name, str) + + self.assertIsInstance(ticket.concert.band_2.id, int) + self.assertIsInstance(ticket.concert.band_2.name, str) + self.assertIsInstance(ticket.concert.band_2.manager, Manager) + + self.assertIsInstance(ticket.concert.band_2.manager.id, int) + self.assertIsInstance(ticket.concert.band_2.manager.name, str) diff --git a/tests/table/test_join_on.py b/tests/table/test_join_on.py new file mode 100644 index 000000000..5be16158c --- /dev/null +++ b/tests/table/test_join_on.py @@ -0,0 +1,109 @@ +from unittest import TestCase + +from piccolo.columns import Serial, Varchar +from piccolo.table import Table + + +class Manager(Table): + id: Serial + name = Varchar(unique=True) + email = Varchar(unique=True) + + +class Band(Table): + id: Serial + name = Varchar(unique=True) + manager_name = Varchar() + + +class Concert(Table): + id: Serial + title = Varchar() + band_name = Varchar() + + +class TestJoinOn(TestCase): + tables = [Manager, Band, Concert] + + def setUp(self): + for table in self.tables: + 
table.create_table().run_sync() + + Manager.insert( + Manager(name="Guido", email="guido@example.com"), + Manager(name="Maz", email="maz@example.com"), + Manager(name="Graydon", email="graydon@example.com"), + ).run_sync() + + Band.insert( + Band(name="Pythonistas", manager_name="Guido"), + Band(name="Rustaceans", manager_name="Graydon"), + ).run_sync() + + Concert.insert( + Concert( + title="Rockfest", + band_name="Pythonistas", + ), + ).run_sync() + + def tearDown(self): + for table in self.tables: + table.alter().drop_table().run_sync() + + def test_join_on(self): + """ + Do a simple join between two tables. + """ + query = Band.select( + Band.name, + Band.manager_name, + Band.manager_name.join_on(Manager.name).email.as_alias( + "manager_email" + ), + ).order_by(Band.id) + + response = query.run_sync() + + self.assertListEqual( + response, + [ + { + "name": "Pythonistas", + "manager_name": "Guido", + "manager_email": "guido@example.com", + }, + { + "name": "Rustaceans", + "manager_name": "Graydon", + "manager_email": "graydon@example.com", + }, + ], + ) + + def test_deeper_join(self): + """ + Do a join between three tables. + """ + response = ( + Concert.select( + Concert.title, + Concert.band_name, + Concert.band_name.join_on(Band.name) + .manager_name.join_on(Manager.name) + .email.as_alias("manager_email"), + ) + .order_by(Concert.id) + .run_sync() + ) + + self.assertListEqual( + response, + [ + { + "title": "Rockfest", + "band_name": "Pythonistas", + "manager_email": "guido@example.com", + } + ], + ) diff --git a/tests/table/test_metaclass.py b/tests/table/test_metaclass.py index f6a274eac..7ff186d73 100644 --- a/tests/table/test_metaclass.py +++ b/tests/table/test_metaclass.py @@ -1,9 +1,17 @@ from unittest import TestCase +from unittest.mock import MagicMock, patch -from piccolo.columns import Secret -from piccolo.columns.column_types import JSON, JSONB, ForeignKey -from piccolo.table import Table -from tests.example_app.tables import Band +from piccolo.columns.column_types import ( + JSON, + JSONB, + Array, + Email, + ForeignKey, + Secret, + Varchar, +) +from piccolo.table import TABLENAME_WARNING, Table +from tests.example_apps.music.tables import Band class TestMetaClass(TestCase): @@ -15,16 +23,22 @@ def test_protected_table_names(self): Some tablenames are forbidden because they're reserved words in the database, and can potentially cause issues. """ - with self.assertRaises(ValueError): + expected_warning = TABLENAME_WARNING.format(tablename="user") + + with patch("piccolo.table.warnings") as warnings: class User(Table): pass - with self.assertRaises(ValueError): + warnings.warn.assert_called_with(expected_warning) + + with patch("piccolo.table.warnings") as warnings: class MyUser(Table, tablename="user"): pass + warnings.warn.assert_called_with(expected_warning) + def test_help_text(self): """ Make sure help_text can be set for the Table. @@ -36,6 +50,38 @@ class Manager(Table, help_text=help_text): self.assertEqual(Manager._meta.help_text, help_text) + def test_schema(self): + """ + Make sure schema can be set for the Table. + """ + schema = "schema_1" + + class Manager(Table, schema=schema): + pass + + self.assertEqual(Manager._meta.schema, schema) + + @patch("piccolo.table.warnings") + def test_schema_from_tablename(self, warnings: MagicMock): + """ + If the tablename contains a '.' we extract the schema name. 
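+
+ e.g. ``tablename="schema_1.manager"`` should result in a ``schema``
+ of ``"schema_1"`` and a ``tablename`` of ``"manager"``.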
+ """ + table = "manager" + schema = "schema_1" + + tablename = f"{schema}.{table}" + + class Manager(Table, tablename=tablename): + pass + + self.assertEqual(Manager._meta.schema, schema) + self.assertEqual(Manager._meta.tablename, table) + + warnings.warn.assert_called_once_with( + "There's a '.' in the tablename - please use the `schema` " + "argument instead." + ) + def test_foreign_key_columns(self): """ Make sure TableMeta.foreign_keys and TableMeta.foreign_key_references @@ -53,14 +99,18 @@ class TableB(Table): def test_secret_columns(self): """ - Make sure TableMeta.secret_columns are setup correctly. + Make sure TableMeta.secret_columns are setup correctly with the + ``secret=True`` argument and ``Secret`` column type. """ class Classified(Table): top_secret = Secret() + confidential = Varchar(secret=True) + public = Varchar() self.assertEqual( - Classified._meta.secret_columns, [Classified.top_secret] + Classified._meta.secret_columns, + [Classified.top_secret, Classified.confidential], ) def test_json_columns(self): @@ -76,6 +126,28 @@ class MyTable(Table): MyTable._meta.json_columns, [MyTable.column_a, MyTable.column_b] ) + def test_email_columns(self): + """ + Make sure ``TableMeta.email_columns`` are setup correctly. + """ + + class MyTable(Table): + column_a = Email() + column_b = Varchar() + + self.assertEqual(MyTable._meta.email_columns, [MyTable.column_a]) + + def test_arry_columns(self): + """ + Make sure ``TableMeta.array_columns`` are setup correctly. + """ + + class MyTable(Table): + column_a = Array(Varchar()) + column_b = Varchar() + + self.assertEqual(MyTable._meta.array_columns, [MyTable.column_a]) + def test_id_column(self): """ Makes sure an id column is added. diff --git a/tests/table/test_objects.py b/tests/table/test_objects.py index 9269362cc..9f065eecc 100644 --- a/tests/table/test_objects.py +++ b/tests/table/test_objects.py @@ -1,42 +1,46 @@ -from ..base import DBTestCase, postgres_only, sqlite_only -from ..example_app.tables import Band +from piccolo.columns.column_types import ForeignKey +from piccolo.testing.test_case import AsyncTableTest +from tests.base import DBTestCase, engines_only, sqlite_only +from tests.example_apps.music.tables import Band, Manager -class TestObjects(DBTestCase): +class TestGetAll(DBTestCase): def test_get_all(self): self.insert_row() response = Band.objects().run_sync() - self.assertTrue(len(response) == 1) + self.assertEqual(len(response), 1) instance = response[0] - self.assertTrue(isinstance(instance, Band)) - self.assertTrue(instance.name == "Pythonistas") + self.assertIsInstance(instance, Band) + self.assertEqual(instance.name, "Pythonistas") # Now try changing the value and saving it. instance.name = "Rustaceans" save_query = instance.save() save_query.run_sync() - self.assertTrue( - Band.select(Band.name).output(as_list=True).run_sync()[0] - == "Rustaceans" + self.assertEqual( + Band.select(Band.name).output(as_list=True).run_sync()[0], + "Rustaceans", ) - @postgres_only + +class TestOffset(DBTestCase): + @engines_only("postgres", "cockroach") def test_offset_postgres(self): """ Postgres can do an offset without a limit clause. 
""" self.insert_rows() - response = Band.objects().order_by(Band.name).offset(1).run_sync() - print(f"response = {response}") + response = Band.objects().order_by(Band.name).offset(1).run_sync() self.assertEqual( - [i.name for i in response], ["Pythonistas", "Rustaceans"] + [i.name for i in response], + ["Pythonistas", "Rustaceans"], ) @sqlite_only @@ -54,8 +58,244 @@ def test_offset_sqlite(self): response = query.run_sync() - print(f"response = {response}") - self.assertEqual( [i.name for i in response], ["Pythonistas", "Rustaceans"] ) + + +class TestGet(DBTestCase): + def test_get(self): + self.insert_row() + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + self.assertEqual(band.name, "Pythonistas") + + def test_get_prefetch(self): + self.insert_rows() + + # With prefetch clause + band = ( + Band.objects() + .get((Band.name == "Pythonistas")) + .prefetch(Band.manager) + .run_sync() + ) + assert band is not None + self.assertIsInstance(band.manager, Manager) # type: ignore + + # Just passing it straight into objects + band = ( + Band.objects(Band.manager) + .get((Band.name == "Pythonistas")) + .run_sync() + ) + assert band is not None + self.assertIsInstance(band.manager, Manager) + + +class TestGetOrCreate(DBTestCase): + def test_simple_where_clause(self): + """ + Make sure `get_or_create` works for simple where clauses. + """ + # When the row doesn't exist in the db: + Band.objects().get_or_create( + Band.name == "Pink Floyd", + defaults={"popularity": 100}, # type: ignore + ).run_sync() + + instance = ( + Band.objects().where(Band.name == "Pink Floyd").first().run_sync() + ) + assert instance is not None + + self.assertIsInstance(instance, Band) + self.assertEqual(instance.name, "Pink Floyd") + self.assertEqual(instance.popularity, 100) + + # When the row already exists in the db: + Band.objects().get_or_create( + Band.name == "Pink Floyd", defaults={Band.popularity: 100} + ).run_sync() + + instance = ( + Band.objects().where(Band.name == "Pink Floyd").first().run_sync() + ) + assert instance is not None + + self.assertIsInstance(instance, Band) + self.assertEqual(instance.name, "Pink Floyd") + self.assertEqual(instance.popularity, 100) + + def test_complex_where_clause(self): + """ + Make sure `get_or_create` works with complex where clauses. + """ + self.insert_rows() + + # When the row already exists in the db: + instance = ( + Band.objects() + .get_or_create( + (Band.name == "Pythonistas") & (Band.popularity == 1000) + ) + .run_sync() + ) + self.assertIsInstance(instance, Band) + self.assertEqual(instance._was_created, False) + + # When the row doesn't exist in the db: + instance = ( + Band.objects() + .get_or_create( + (Band.name == "Pythonistas2") & (Band.popularity == 2000) + ) + .run_sync() + ) + self.assertIsInstance(instance, Band) + self.assertEqual(instance._was_created, True) + + def test_very_complex_where_clause(self): + """ + Make sure `get_or_create` works with very complex where clauses. 
+ """ + self.insert_rows() + + # When the row already exists in the db: + instance = ( + Band.objects() + .get_or_create( + (Band.name == "Pythonistas") + & (Band.popularity > 0) + & (Band.popularity < 5000) + ) + .run_sync() + ) + self.assertIsInstance(instance, Band) + self.assertEqual(instance._was_created, False) + + # When the row doesn't exist in the db: + instance = ( + Band.objects() + .get_or_create( + (Band.name == "Pythonistas2") + & (Band.popularity > 10) + & (Band.popularity < 5000) + ) + .run_sync() + ) + self.assertIsInstance(instance, Band) + self.assertEqual(instance._was_created, True) + + # The values in the > and < should be ignored, and the default should + # be used for the column. + self.assertEqual(instance.popularity, 0) + + def test_joins(self): + """ + Make sure that that `get_or_create` creates rows correctly when using + joins. + """ + instance = ( + Band.objects() + .get_or_create( + (Band.name == "My new band") + & (Band.manager.name == "Excellent manager") + ) + .run_sync() + ) + self.assertIsInstance(instance, Band) + self.assertEqual(instance._was_created, True) + + # We want to make sure the band name isn't 'Excellent manager' by + # mistake. + self.assertEqual(Band.name, "My new band") + + def test_prefetch_existing_object(self): + """ + Make sure that that `get_or_create` works with the `prefetch` clause, + when it's an existing row in the database. + """ + self.insert_rows() + + # With prefetch clause + band = ( + Band.objects() + .get_or_create((Band.name == "Pythonistas")) + .prefetch(Band.manager) + .run_sync() + ) + self.assertIsInstance(band.manager, Manager) # type: ignore + self.assertEqual(band.manager.name, "Guido") # type: ignore + + # Just passing it straight into objects + band = ( + Band.objects(Band.manager) + .get_or_create((Band.name == "Pythonistas")) + .run_sync() + ) + self.assertIsInstance(band.manager, Manager) + self.assertEqual(band.manager.name, "Guido") + + def test_prefetch_new_object(self): + """ + Make sure that that `get_or_create` works with the `prefetch` clause, + when the row is being created in the database. + """ + manager = Manager({Manager.name: "Guido"}) + manager.save().run_sync() + + # With prefetch clause + band = ( + Band.objects() + .get_or_create( + (Band.name == "New Band") & (Band.manager == manager) + ) + .prefetch(Band.manager) + .run_sync() + ) + self.assertIsInstance(band.manager, Manager) # type: ignore + self.assertEqual(band.name, "New Band") # type: ignore + + # Just passing it straight into objects + band = ( + Band.objects(Band.manager) + .get_or_create( + (Band.name == "New Band 2") & (Band.manager == manager) + ) + .run_sync() + ) + self.assertIsInstance(band.manager, Manager) + self.assertEqual(band.name, "New Band 2") + self.assertEqual(band.manager.name, "Guido") + + +class BandNotNull(Band, tablename="band"): + manager = ForeignKey(Manager, null=False) + + +class TestGetOrCreateNotNull(AsyncTableTest): + + tables = [BandNotNull, Manager] + + async def test_not_null(self): + """ + There was a bug where `get_or_create` would fail for columns with + `default=None` and `null=False`, even if the value for those columns + was specified in the where clause. 
+ + https://github.com/piccolo-orm/piccolo/issues/1152 + + """ + + manager = Manager({Manager.name: "Test"}) + await manager.save() + + self.assertIsInstance( + await BandNotNull.objects().get_or_create( + BandNotNull.manager == manager + ), + BandNotNull, + ) diff --git a/tests/table/test_output.py b/tests/table/test_output.py index 2eeccaa2f..ecfc997bc 100644 --- a/tests/table/test_output.py +++ b/tests/table/test_output.py @@ -1,8 +1,9 @@ import json from unittest import TestCase +from piccolo.table import create_db_tables_sync, drop_db_tables_sync from tests.base import DBTestCase -from tests.example_app.tables import Band, RecordingStudio +from tests.example_apps.music.tables import Band, Instrument, RecordingStudio class TestOutputList(DBTestCase): @@ -10,7 +11,7 @@ def test_output_as_list(self): self.insert_row() response = Band.select(Band.name).output(as_list=True).run_sync() - self.assertTrue(response == ["Pythonistas"]) + self.assertEqual(response, ["Pythonistas"]) # Make sure that if no rows are found, an empty list is returned. empty_response = ( @@ -19,7 +20,7 @@ def test_output_as_list(self): .output(as_list=True) .run_sync() ) - self.assertTrue(empty_response == []) + self.assertEqual(empty_response, []) class TestOutputJSON(DBTestCase): @@ -28,34 +29,132 @@ def test_output_as_json(self): response = Band.select(Band.name).output(as_json=True).run_sync() - self.assertTrue(json.loads(response) == [{"name": "Pythonistas"}]) + self.assertEqual(json.loads(response), [{"name": "Pythonistas"}]) class TestOutputLoadJSON(TestCase): + tables = [RecordingStudio, Instrument] + json = {"a": 123} + def setUp(self): - RecordingStudio.create_table().run_sync() + create_db_tables_sync(*self.tables) + + recording_studio = RecordingStudio( + { + RecordingStudio.facilities: self.json, + RecordingStudio.facilities_b: self.json, + } + ) + recording_studio.save().run_sync() + + instrument = Instrument( + { + Instrument.recording_studio: recording_studio, + Instrument.name: "Piccolo", + } + ) + instrument.save().run_sync() def tearDown(self): - RecordingStudio.alter().drop_table().run_sync() + drop_db_tables_sync(*self.tables) def test_select(self): - json = {"a": 123} + results = ( + RecordingStudio.select( + RecordingStudio.facilities, RecordingStudio.facilities_b + ) + .output(load_json=True) + .run_sync() + ) - RecordingStudio(facilities=json, facilities_b=json).save().run_sync() + self.assertEqual( + results, + [ + { + "facilities": self.json, + "facilities_b": self.json, + } + ], + ) + + def test_join(self): + """ + Make sure it works correctly when the JSON column is on a joined table. 
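+
+ The joined JSON column should come back under a flat
+ "recording_studio.facilities" key, decoded into a Python dict
+ because of ``load_json=True``.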
- results = RecordingStudio.select().output(load_json=True).run_sync() + https://github.com/piccolo-orm/piccolo/issues/1001 + + """ + results = ( + Instrument.select( + Instrument.name, + Instrument.recording_studio._.facilities, + ) + .output(load_json=True) + .run_sync() + ) self.assertEqual( results, - [{"id": 1, "facilities": {"a": 123}, "facilities_b": {"a": 123}}], + [ + { + "name": "Piccolo", + "recording_studio.facilities": self.json, + } + ], ) - def test_objects(self): - json = {"a": 123} + def test_join_with_alias(self): + results = ( + Instrument.select( + Instrument.name, + Instrument.recording_studio._.facilities.as_alias( + "facilities" + ), + ) + .output(load_json=True) + .run_sync() + ) - RecordingStudio(facilities=json, facilities_b=json).save().run_sync() + self.assertEqual( + results, + [ + { + "name": "Piccolo", + "facilities": self.json, + } + ], + ) + def test_objects(self): results = RecordingStudio.objects().output(load_json=True).run_sync() + self.assertEqual(results[0].facilities, self.json) + self.assertEqual(results[0].facilities_b, self.json) + - self.assertEqual(results[0].facilities, json) - self.assertEqual(results[0].facilities_b, json) +class TestOutputNested(DBTestCase): + def test_output_nested(self): + self.insert_row() + + response = ( + Band.select(Band.name, Band.manager.name) + .output(nested=True) + .run_sync() + ) + self.assertEqual( + response, [{"name": "Pythonistas", "manager": {"name": "Guido"}}] + ) + + def test_output_nested_with_first(self): + self.insert_row() + + response = ( + Band.select(Band.name, Band.manager.name) + .first() + .output(nested=True) + .run_sync() + ) + assert response is not None + self.assertDictEqual( + response, # type: ignore + {"name": "Pythonistas", "manager": {"name": "Guido"}}, + ) diff --git a/tests/table/test_raw.py b/tests/table/test_raw.py index dc6ef530b..09abb9de8 100644 --- a/tests/table/test_raw.py +++ b/tests/table/test_raw.py @@ -1,27 +1,60 @@ -from ..base import DBTestCase -from ..example_app.tables import Band +from tests.base import DBTestCase, engine_is +from tests.example_apps.music.tables import Band class TestRaw(DBTestCase): def test_raw_without_args(self): self.insert_row() - response = Band.raw("select * from band").run_sync() + response = Band.raw("SELECT * FROM band").run_sync() - self.assertDictEqual( - response[0], - {"id": 1, "name": "Pythonistas", "manager": 1, "popularity": 1000}, - ) + if engine_is("cockroach"): + self.assertDictEqual( + response[0], + { + "id": response[0]["id"], + "name": "Pythonistas", + "manager": response[0]["manager"], + "popularity": 1000, + }, + ) + else: + self.assertDictEqual( + response[0], + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "popularity": 1000, + }, + ) def test_raw_with_args(self): self.insert_rows() response = Band.raw( - "select * from band where name = {}", "Pythonistas" + "SELECT * FROM band WHERE name = {}", "Pythonistas" ).run_sync() - self.assertTrue(len(response) == 1) - self.assertDictEqual( - response[0], - {"id": 1, "name": "Pythonistas", "manager": 1, "popularity": 1000}, - ) + self.assertEqual(len(response), 1) + + if engine_is("cockroach"): + self.assertDictEqual( + response[0], + { + "id": response[0]["id"], + "name": "Pythonistas", + "manager": response[0]["manager"], + "popularity": 1000, + }, + ) + else: + self.assertDictEqual( + response[0], + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "popularity": 1000, + }, + ) diff --git a/tests/table/test_ref.py b/tests/table/test_ref.py index 
875fa74a2..0d6645eff 100644
--- a/tests/table/test_ref.py
+++ b/tests/table/test_ref.py
@@ -1,10 +1,10 @@
 from unittest import TestCase

 from piccolo.columns.column_types import Varchar
-from tests.example_app.tables import Band
+from tests.example_apps.music.tables import Band


 class TestRef(TestCase):
 def test_ref(self):
 column = Band.ref("manager.name")
- self.assertTrue(isinstance(column, Varchar))
+ self.assertIsInstance(column, Varchar)
diff --git a/tests/table/test_refresh.py b/tests/table/test_refresh.py
new file mode 100644
index 000000000..ce002bb9a
--- /dev/null
+++ b/tests/table/test_refresh.py
@@ -0,0 +1,298 @@
+from typing import cast
+
+from piccolo.testing.test_case import TableTest
+from tests.base import DBTestCase
+from tests.example_apps.music.tables import (
+ Band,
+ Concert,
+ Manager,
+ RecordingStudio,
+ Venue,
+)
+
+
+class TestRefresh(DBTestCase):
+ def setUp(self):
+ super().setUp()
+ self.insert_rows()
+
+ def test_refresh(self) -> None:
+ """
+ Make sure ``refresh`` works, with no columns specified.
+ """
+ # Fetch an instance from the database.
+ band = Band.objects().get(Band.name == "Pythonistas").run_sync()
+ assert band is not None
+ initial_data = band.to_dict()
+
+ # Modify the data in the database.
+ Band.update(
+ {Band.name: Band.name + "!!!", Band.popularity: 8000}
+ ).where(Band.name == "Pythonistas").run_sync()
+
+ # Refresh `band`, and make sure it has the correct data.
+ band.refresh().run_sync()
+
+ self.assertEqual(band.name, "Pythonistas!!!")
+ self.assertEqual(band.popularity, 8000)
+ self.assertEqual(band.id, initial_data["id"])
+
+ def test_refresh_with_prefetch(self) -> None:
+ """
+ Make sure ``refresh`` works, when the object used prefetch to get
+ nested objects (the nested objects should be updated too).
+ """
+ band = (
+ Band.objects(Band.manager)
+ .where(Band.name == "Pythonistas")
+ .first()
+ .run_sync()
+ )
+ assert band is not None
+
+ # Modify the data in the database.
+ Manager.update({Manager.name: "Guido!!!"}).where(
+ Manager.name == "Guido"
+ ).run_sync()
+
+ # Refresh `band`, and make sure it has the correct data.
+ band.refresh().run_sync()
+
+ self.assertEqual(band.manager.name, "Guido!!!")
+
+ def test_refresh_with_prefetch_multiple_layers_deep(self) -> None:
+ """
+ Make sure ``refresh`` works, when the object used prefetch to get
+ nested objects (the nested objects should be updated too).
+ """
+ band = (
+ Band.objects(Band.manager)
+ .where(Band.name == "Pythonistas")
+ .first()
+ .run_sync()
+ )
+ assert band is not None
+
+ # Modify the data in the database.
+ Manager.update({Manager.name: "Guido!!!"}).where(
+ Manager.name == "Guido"
+ ).run_sync()
+
+ # Refresh `band`, and make sure it has the correct data.
+ band.refresh().run_sync()
+
+ self.assertEqual(band.manager.name, "Guido!!!")
+
+ def test_columns(self) -> None:
+ """
+ Make sure ``refresh`` works, when columns are specified.
+ """
+ # Fetch an instance from the database.
+ band = Band.objects().get(Band.name == "Pythonistas").run_sync()
+ assert band is not None
+ initial_data = band.to_dict()
+
+ # Modify the data in the database.
+ Band.update(
+ {Band.name: Band.name + "!!!", Band.popularity: 8000}
+ ).where(Band.name == "Pythonistas").run_sync()
+
+ # Refresh `band`, and make sure it has the correct data.
+ query = band.refresh(columns=[Band.name])
+ self.assertEqual(
+ [i._meta.name for i in query._columns],
+ ["name"],
+ )
+ query.run_sync()
+
+ self.assertEqual(band.name, "Pythonistas!!!")
+ self.assertEqual(band.popularity, initial_data["popularity"])
+ self.assertEqual(band.id, initial_data["id"])
+
+ def test_error_when_not_in_db(self) -> None:
+ """
+ Make sure we can't refresh an instance which hasn't been saved in the
+ database.
+ """
+ band = Band()
+
+ with self.assertRaises(ValueError) as manager:
+ band.refresh().run_sync()
+
+ self.assertEqual(
+ "The instance doesn't exist in the database.",
+ str(manager.exception),
+ )
+
+ def test_error_when_pk_is_none(self) -> None:
+ """
+ Make sure we can't refresh an instance when the primary key value isn't
+ set.
+ """
+ band = Band.objects().first().run_sync()
+ assert band is not None
+ band.id = None
+
+ with self.assertRaises(ValueError) as manager:
+ band.refresh().run_sync()
+
+ self.assertEqual(
+ "The instance's primary key value isn't defined.",
+ str(manager.exception),
+ )
+
+
+class TestRefreshWithPrefetch(TableTest):
+
+ tables = [Manager, Band, Concert, Venue]
+
+ def setUp(self):
+ super().setUp()
+
+ self.manager = Manager({Manager.name: "Guido"})
+ self.manager.save().run_sync()
+
+ self.band = Band(
+ {Band.name: "Pythonistas", Band.manager: self.manager}
+ )
+ self.band.save().run_sync()
+
+ self.concert = Concert({Concert.band_1: self.band})
+ self.concert.save().run_sync()
+
+ def test_single_layer(self) -> None:
+ """
+ Make sure ``refresh`` works, when the object used prefetch to get
+ nested objects (the nested objects should be updated too).
+ """
+ band = (
+ Band.objects(Band.manager)
+ .where(Band.name == "Pythonistas")
+ .first()
+ .run_sync()
+ )
+ assert band is not None
+
+ # Modify the data in the database.
+ self.manager.name = "Guido!!!"
+ self.manager.save().run_sync()
+
+ # Refresh `band`, and make sure it has the correct data.
+ band.refresh().run_sync()
+ self.assertEqual(band.manager.name, "Guido!!!")
+
+ def test_multiple_layers(self) -> None:
+ """
+ Make sure ``refresh`` works when ``prefetch`` was used to fetch objects
+ multiple layers deep.
+ """
+ concert = (
+ Concert.objects(Concert.band_1._.manager)
+ .where(Concert.band_1._.name == "Pythonistas")
+ .first()
+ .run_sync()
+ )
+ assert concert is not None
+
+ # Modify the data in the database.
+ self.manager.name = "Guido!!!"
+ self.manager.save().run_sync()
+
+ concert.refresh().run_sync()
+ self.assertEqual(concert.band_1.manager.name, "Guido!!!")
+
+ def test_updated_foreign_key(self) -> None:
+ """
+ If a foreign key now references a different row, make sure this
+ is refreshed correctly.
+ """
+ band = (
+ Band.objects(Band.manager)
+ .where(Band.name == "Pythonistas")
+ .first()
+ .run_sync()
+ )
+ assert band is not None
+
+ # Assign a different manager to the band
+ new_manager = Manager({Manager.name: "New Manager"})
+ new_manager.save().run_sync()
+ Band.update({Band.manager: new_manager.id}, force=True).run_sync()
+
+ # Refresh `band`, and make sure it references the new manager.
+ band.refresh().run_sync()
+ self.assertEqual(band.manager.id, new_manager.id)
+ self.assertEqual(band.manager.name, "New Manager")
+
+ def test_foreign_key_set_to_null(self):
+ """
+ Make sure that if the foreign key was set to null, ``refresh``
+ sets the nested object to ``None``.
+ """ + band = ( + Band.objects(Band.manager) + .where(Band.name == "Pythonistas") + .first() + .run_sync() + ) + assert band is not None + + # Remove the manager from band + Band.update({Band.manager: None}, force=True).run_sync() + + # Refresh `band`, and make sure the foreign key value is now `None`, + # instead of a nested object. + band.refresh().run_sync() + self.assertIsNone(band.manager) + + def test_exception(self) -> None: + """ + We don't currently let the user refresh specific fields from nested + objects - an exception should be raised. + """ + with self.assertRaises(ValueError): + self.concert.refresh(columns=[Concert.band_1._.manager]).run_sync() + + # Shouldn't raise an exception: + self.concert.refresh(columns=[Concert.band_1]).run_sync() + + +class TestRefreshWithLoadJSON(TableTest): + + tables = [RecordingStudio] + + def setUp(self): + super().setUp() + + self.recording_studio = RecordingStudio( + {RecordingStudio.facilities: {"piano": True}} + ) + self.recording_studio.save().run_sync() + + def test_load_json(self): + """ + Make sure we can refresh an object, and load the JSON as a Python + object. + """ + RecordingStudio.update( + {RecordingStudio.facilities: {"electric piano": True}}, + force=True, + ).run_sync() + + # Refresh without load_json: + self.recording_studio.refresh().run_sync() + + self.assertEqual( + # Remove the white space, because some versions of Python add + # whitespace around JSON, and some don't. + self.recording_studio.facilities.replace(" ", ""), + '{"electricpiano":true}', + ) + + # Refresh with load_json: + self.recording_studio.refresh(load_json=True).run_sync() + + self.assertDictEqual( + cast(dict, self.recording_studio.facilities), + {"electric piano": True}, + ) diff --git a/tests/table/test_repr.py b/tests/table/test_repr.py index 055637d31..37a98d37b 100644 --- a/tests/table/test_repr.py +++ b/tests/table/test_repr.py @@ -1,30 +1,15 @@ -from ..base import DBTestCase, postgres_only, sqlite_only -from ..example_app.tables import Manager +from tests.base import DBTestCase +from tests.example_apps.music.tables import Manager class TestTableRepr(DBTestCase): - @postgres_only def test_repr_postgres(self): self.assertEqual( Manager().__repr__(), - '', + "", ) self.insert_row() manager = Manager.objects().first().run_sync() - self.assertEqual( - manager.__repr__(), f"" - ) - - @sqlite_only - def test_repr_sqlite(self): - self.assertEqual( - Manager().__repr__(), - '', - ) - - self.insert_row() - manager = Manager.objects().first().run_sync() - self.assertEqual( - manager.__repr__(), f"" - ) + assert manager is not None + self.assertEqual(manager.__repr__(), f"") diff --git a/tests/table/test_select.py b/tests/table/test_select.py index 27593775e..d41962a01 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -1,11 +1,27 @@ +import datetime from unittest import TestCase +import pytest + from piccolo.apps.user.tables import BaseUser +from piccolo.columns import Date, Varchar from piccolo.columns.combination import WhereRaw -from piccolo.query.methods.select import Avg, Count, Max, Min, Sum - -from ..base import DBTestCase, postgres_only, sqlite_only -from ..example_app.tables import Band, Concert, Manager +from piccolo.query import OrderByRaw +from piccolo.query.functions.aggregate import Avg, Count, Max, Min, Sum +from piccolo.query.methods.select import SelectRaw +from piccolo.query.mixins import DistinctOnError +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import ( + 
DBTestCase, + engine_is, + engine_version_lt, + engines_only, + engines_skip, + is_running_cockroach, + is_running_sqlite, + sqlite_only, +) +from tests.example_apps.music.tables import Band, Concert, Manager, Venue class TestSelect(DBTestCase): @@ -13,18 +29,32 @@ def test_query_all_columns(self): self.insert_row() response = Band.select().run_sync() - print(f"response = {response}") - self.assertDictEqual( - response[0], - {"id": 1, "name": "Pythonistas", "manager": 1, "popularity": 1000}, - ) + if engine_is("cockroach"): + self.assertDictEqual( + response[0], + { + "id": response[0]["id"], + "name": "Pythonistas", + "manager": response[0]["manager"], + "popularity": 1000, + }, + ) + else: + self.assertDictEqual( + response[0], + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "popularity": 1000, + }, + ) def test_query_some_columns(self): self.insert_row() response = Band.select(Band.name).run_sync() - print(f"response = {response}") self.assertDictEqual(response[0], {"name": "Pythonistas"}) @@ -65,27 +95,145 @@ def test_where_equals(self): ) self.assertEqual(response, [{"name": "Pythonistas"}]) - def test_where_like(self): - self.insert_rows() - + # check multiple arguments inside WHERE clause response = ( - Band.select(Band.name).where(Band.name.like("Python%")).run_sync() + Band.select(Band.name) + .where(Band.manager.id == 1, Band.popularity == 500) + .run_sync() ) + self.assertEqual(response, []) - print(f"response = {response}") - + # check empty WHERE clause + response = Band.select(Band.name).where().run_sync() self.assertEqual(response, [{"name": "Pythonistas"}]) - def test_where_ilike(self): + @engines_only("postgres", "cockroach") + def test_where_like_postgres(self): + """ + Postgres' LIKE is case sensitive. + """ self.insert_rows() - response = ( - Band.select(Band.name).where(Band.name.ilike("python%")).run_sync() - ) + for like_query in ("Python%", "Pythonistas", "%istas", "%ist%"): + response = ( + Band.select(Band.name) + .where(Band.name.like(like_query)) + .run_sync() + ) - print(f"response = {response}") + self.assertEqual(response, [{"name": "Pythonistas"}]) + + for like_query in ( + "PyThonISTAs", + "PYth%", + "%ISTAS", + "%Ist%", + "PYTHONISTAS", + ): + response = ( + Band.select(Band.name) + .where(Band.name.like(like_query)) + .run_sync() + ) - self.assertEqual(response, [{"name": "Pythonistas"}]) + self.assertEqual(response, []) + + @sqlite_only + def test_where_like_sqlite(self): + """ + SQLite's LIKE is case insensitive for ASCII characters, + i.e. a == 'A' -> True. + """ + self.insert_rows() + + for like_query in ( + "Python%", + "Pythonistas", + "%istas", + "%ist%", + "PYTHONISTAS", + ): + response = ( + Band.select(Band.name) + .where(Band.name.like(like_query)) + .run_sync() + ) + + self.assertEqual(response, [{"name": "Pythonistas"}]) + + for like_query in ( + "xyz", + "XYZ%", + "%xyz", + "%xyz%", + ): + response = ( + Band.select(Band.name) + .where(Band.name.like(like_query)) + .run_sync() + ) + + self.assertEqual(response, []) + + @sqlite_only + def test_where_ilike_sqlite(self): + """ + SQLite doesn't support ILIKE, so it's just a proxy to LIKE. We still + have a test to make sure it proxies correctly. 
+ """ + self.insert_rows() + + for ilike_query in ( + "Python%", + "Pythonistas", + "pythonistas", + "PytHonIStas", + "%istas", + "%ist%", + "%IST%", + ): + self.assertEqual( + Band.select(Band.name) + .where(Band.name.ilike(ilike_query)) + .run_sync(), + Band.select(Band.name) + .where(Band.name.like(ilike_query)) + .run_sync(), + ) + + @engines_only("postgres", "cockroach") + def test_where_ilike_postgres(self): + """ + Only Postgres has ILIKE - it's not in the SQL standard. It's for + case insensitive matching, i.e. 'Foo' == 'foo' -> True. + """ + self.insert_rows() + + for ilike_query in ( + "Python%", + "Pythonistas", + "pythonistas", + "PytHonIStas", + "%istas", + "%ist%", + "%IST%", + ): + response = ( + Band.select(Band.name) + .where(Band.name.ilike(ilike_query)) + .run_sync() + ) + + self.assertEqual(response, [{"name": "Pythonistas"}]) + + for ilike_query in ("Pythonistas1", "%123", "%xyz%"): + response = ( + Band.select(Band.name) + .where(Band.name.ilike(ilike_query)) + .run_sync() + ) + + self.assertEqual(response, []) def test_where_not_like(self): self.insert_rows() @@ -97,8 +245,6 @@ def test_where_not_like(self): .run_sync() ) - print(f"response = {response}") - self.assertEqual( response, [{"name": "CSharps"}, {"name": "Rustaceans"}] ) @@ -110,10 +256,66 @@ def test_where_greater_than(self): Band.select(Band.name).where(Band.popularity > 1000).run_sync() ) - print(f"response = {response}") - self.assertEqual(response, [{"name": "Rustaceans"}]) + def test_is_in(self): + self.insert_rows() + + response = ( + Band.select(Band.name) + .where(Band.manager._.name.is_in(["Guido"])) + .run_sync() + ) + + self.assertListEqual(response, [{"name": "Pythonistas"}]) + + def test_is_in_subquery(self): + self.insert_rows() + + # This is a contrived example, just for testing. + response = ( + Band.select(Band.name) + .where( + Band.manager.is_in( + Manager.select(Manager.id).where(Manager.name == "Guido") + ) + ) + .run_sync() + ) + + self.assertListEqual(response, [{"name": "Pythonistas"}]) + + def test_not_in(self): + self.insert_rows() + + response = ( + Band.select(Band.name) + .where(Band.manager._.name.not_in(["Guido"])) + .run_sync() + ) + + self.assertListEqual( + response, [{"name": "Rustaceans"}, {"name": "CSharps"}] + ) + + def test_not_in_subquery(self): + self.insert_rows() + + # This is a contrived example, just for testing. + response = ( + Band.select(Band.name) + .where( + Band.manager.not_in( + Manager.select(Manager.id).where(Manager.name == "Guido") + ) + ) + .run_sync() + ) + + self.assertListEqual( + response, [{"name": "Rustaceans"}, {"name": "CSharps"}] + ) + def test_where_is_null(self): self.insert_rows() @@ -128,6 +330,15 @@ def test_where_is_null(self): response = query.run_sync() self.assertEqual(response, [{"name": "Managerless"}]) + def test_where_bool(self): + """ + If passing a boolean into a where clause, an exception should be + raised. It's possible for a user to do this by accident, for example + ``where(Band.has_drummer is None)``, which evaluates to a boolean. 
+        """
+        with self.assertRaises(ValueError):
+            Band.select().where(False)  # type: ignore
+
     def test_where_is_not_null(self):
         self.insert_rows()
 
@@ -159,8 +370,6 @@ def test_where_greater_equal_than(self):
             .run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(
             response, [{"name": "Pythonistas"}, {"name": "Rustaceans"}]
         )
@@ -172,8 +381,6 @@ def test_where_less_than(self):
             Band.select(Band.name).where(Band.popularity < 1000).run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(response, [{"name": "CSharps"}])
 
     def test_where_less_equal_than(self):
@@ -183,8 +390,6 @@ def test_where_less_equal_than(self):
             Band.select(Band.name).where(Band.popularity <= 1000).run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(
             response, [{"name": "Pythonistas"}, {"name": "CSharps"}]
         )
@@ -201,8 +406,6 @@ def test_where_raw(self):
             .run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(response, [{"name": "Pythonistas"}])
 
     def test_where_raw_with_args(self):
@@ -218,8 +421,6 @@ def test_where_raw_with_args(self):
             .run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(response, [{"name": "Pythonistas"}])
 
     def test_where_raw_combined_with_where(self):
@@ -236,8 +437,6 @@ def test_where_raw_combined_with_where(self):
             .run_sync()
        )
 
-        print(f"response = {response}")
-
         self.assertEqual(
             response, [{"name": "Pythonistas"}, {"name": "Rustaceans"}]
         )
@@ -251,8 +450,6 @@ def test_where_and(self):
             .run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(response, [{"name": "Pythonistas"}])
 
     def test_where_or(self):
@@ -265,12 +462,11 @@ def test_where_or(self):
             .run_sync()
         )
 
-        print(f"response = {response}")
-
         self.assertEqual(
             response, [{"name": "CSharps"}, {"name": "Rustaceans"}]
         )
 
+    @engines_skip("cockroach")
     def test_multiple_where(self):
         """
         Test that chaining multiple where clauses results in an AND.
@@ -285,11 +481,10 @@ def test_multiple_where(self):
 
         response = query.run_sync()
 
-        print(f"response = {response}")
-
         self.assertEqual(response, [{"name": "Rustaceans"}])
 
-        self.assertTrue("AND" in query.__str__())
+        self.assertIn("AND", query.__str__())
 
+    @engines_skip("cockroach")
     def test_complex_where(self):
         """
         Test a complex where clause - combining AND and OR.
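 
         A sketch of the kind of clause under test, combining ``&`` (AND)
         and ``|`` (OR) - illustrative only::
 
             Band.select(Band.name).where(
                 ((Band.popularity >= 800) & (Band.popularity < 1000))
                 | (Band.name == "CSharps")
             )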
@@ -307,8 +502,6 @@ def test_complex_where(self): response = query.run_sync() - print(f"response = {response}") - self.assertEqual( response, [{"name": "CSharps"}, {"name": "Rustaceans"}] ) @@ -320,11 +513,9 @@ def test_limit(self): Band.select(Band.name).order_by(Band.name).limit(1).run_sync() ) - print(f"response = {response}") - self.assertEqual(response, [{"name": "CSharps"}]) - @postgres_only + @engines_only("postgres", "cockroach") def test_offset_postgres(self): self.insert_rows() @@ -332,8 +523,6 @@ def test_offset_postgres(self): Band.select(Band.name).order_by(Band.name).offset(1).run_sync() ) - print(f"response = {response}") - self.assertEqual( response, [{"name": "Pythonistas"}, {"name": "Rustaceans"}] ) @@ -353,8 +542,6 @@ def test_offset_sqlite(self): query = query.limit(5) response = query.run_sync() - print(f"response = {response}") - self.assertEqual( response, [{"name": "Pythonistas"}, {"name": "Rustaceans"}] ) @@ -366,64 +553,56 @@ def test_first(self): Band.select(Band.name).order_by(Band.name).first().run_sync() ) - print(f"response = {response}") - self.assertEqual(response, {"name": "CSharps"}) - def test_order_by_ascending(self): + def test_count(self): self.insert_rows() - response = ( - Band.select(Band.name).order_by(Band.name).limit(1).run_sync() - ) - - print(f"response = {response}") + response = Band.count().where(Band.name == "Pythonistas").run_sync() - self.assertEqual(response, [{"name": "CSharps"}]) + self.assertEqual(response, 1) - def test_order_by_decending(self): + def test_distinct(self): + """ + Make sure the distinct clause works. + """ self.insert_rows() - - response = ( - Band.select(Band.name) - .order_by(Band.name, ascending=False) - .limit(1) - .run_sync() - ) - - print(f"response = {response}") - - self.assertEqual(response, [{"name": "Rustaceans"}]) - - def test_count(self): self.insert_rows() - response = Band.count().where(Band.name == "Pythonistas").run_sync() + query = Band.select(Band.name).where(Band.name == "Pythonistas") + self.assertNotIn("DISTINCT", query.__str__()) - print(f"response = {response}") + response = query.run_sync() + self.assertEqual( + response, [{"name": "Pythonistas"}, {"name": "Pythonistas"}] + ) - self.assertEqual(response, 1) + query = query.distinct() + self.assertIn("DISTINCT", query.__str__()) - def test_distinct(self): + response = query.run_sync() + self.assertEqual(response, [{"name": "Pythonistas"}]) + + def test_distinct_on(self): """ - Make sure the distinct clause works. + Make sure the distinct clause works, with the ``on`` param. 
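+
+        A sketch of what an ``on`` call looks like (``TestDistinctOn``
+        below exercises it fully)::
+
+            Band.select(Band.name).distinct(on=[Band.name])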
""" self.insert_rows() self.insert_rows() query = Band.select(Band.name).where(Band.name == "Pythonistas") - self.assertTrue("DISTINCT" not in query.__str__()) + self.assertNotIn("DISTINCT", query.__str__()) response = query.run_sync() - self.assertTrue( - response == [{"name": "Pythonistas"}, {"name": "Pythonistas"}] + self.assertEqual( + response, [{"name": "Pythonistas"}, {"name": "Pythonistas"}] ) query = query.distinct() - self.assertTrue("DISTINCT" in query.__str__()) + self.assertIn("DISTINCT", query.__str__()) response = query.run_sync() - self.assertTrue(response == [{"name": "Pythonistas"}]) + self.assertEqual(response, [{"name": "Pythonistas"}]) def test_count_group_by(self): """ @@ -439,13 +618,13 @@ def test_count_group_by(self): .run_sync() ) - self.assertTrue( - response - == [ + self.assertEqual( + response, + [ {"name": "CSharps", "count": 2}, {"name": "Pythonistas", "count": 2}, {"name": "Rustaceans", "count": 2}, - ] + ], ) def test_count_with_alias_group_by(self): @@ -462,13 +641,13 @@ def test_count_with_alias_group_by(self): .run_sync() ) - self.assertTrue( - response - == [ + self.assertEqual( + response, + [ {"name": "CSharps", "total": 2}, {"name": "Pythonistas", "total": 2}, {"name": "Rustaceans", "total": 2}, - ] + ], ) def test_count_with_as_alias_group_by(self): @@ -485,13 +664,13 @@ def test_count_with_as_alias_group_by(self): .run_sync() ) - self.assertTrue( - response - == [ + self.assertEqual( + response, + [ {"name": "CSharps", "total": 2}, {"name": "Pythonistas", "total": 2}, {"name": "Rustaceans", "total": 2}, - ] + ], ) def test_count_column_group_by(self): @@ -525,14 +704,14 @@ def test_count_column_group_by(self): # differently when sorting. response = sorted(response, key=lambda x: x["manager.name"] or "") - self.assertTrue( - response - == [ + self.assertEqual( + response, + [ {"manager.name": None, "count": 0}, {"manager.name": "Graydon", "count": 2}, {"manager.name": "Guido", "count": 2}, {"manager.name": "Mads", "count": 2}, - ] + ], ) # This time the nulls should be counted, as we omit the column argument @@ -546,22 +725,23 @@ def test_count_column_group_by(self): response = sorted(response, key=lambda x: x["manager.name"] or "") - self.assertTrue( - response - == [ + self.assertEqual( + response, + [ {"manager.name": None, "count": 1}, {"manager.name": "Graydon", "count": 2}, {"manager.name": "Guido", "count": 2}, {"manager.name": "Mads", "count": 2}, - ] + ], ) def test_avg(self): self.insert_rows() response = Band.select(Avg(Band.popularity)).first().run_sync() + assert response is not None - self.assertTrue(float(response["avg"]) == 1003.3333333333334) + self.assertEqual(float(response["avg"]), 1003.3333333333334) def test_avg_alias(self): self.insert_rows() @@ -571,10 +751,9 @@ def test_avg_alias(self): .first() .run_sync() ) + assert response is not None - self.assertTrue( - float(response["popularity_avg"]) == 1003.3333333333334 - ) + self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) def test_avg_as_alias_method(self): self.insert_rows() @@ -584,10 +763,9 @@ def test_avg_as_alias_method(self): .first() .run_sync() ) + assert response is not None - self.assertTrue( - float(response["popularity_avg"]) == 1003.3333333333334 - ) + self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) def test_avg_with_where_clause(self): self.insert_rows() @@ -598,8 +776,9 @@ def test_avg_with_where_clause(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["avg"] == 1500) + 
self.assertEqual(response["avg"], 1500) def test_avg_alias_with_where_clause(self): """ @@ -614,8 +793,9 @@ def test_avg_alias_with_where_clause(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_avg"] == 1500) + self.assertEqual(response["popularity_avg"], 1500) def test_avg_as_alias_method_with_where_clause(self): """ @@ -630,15 +810,17 @@ def test_avg_as_alias_method_with_where_clause(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_avg"] == 1500) + self.assertEqual(response["popularity_avg"], 1500) def test_max(self): self.insert_rows() response = Band.select(Max(Band.popularity)).first().run_sync() + assert response is not None - self.assertTrue(response["max"] == 2000) + self.assertEqual(response["max"], 2000) def test_max_alias(self): self.insert_rows() @@ -648,8 +830,9 @@ def test_max_alias(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_max"] == 2000) + self.assertEqual(response["popularity_max"], 2000) def test_max_as_alias_method(self): self.insert_rows() @@ -659,15 +842,17 @@ def test_max_as_alias_method(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_max"] == 2000) + self.assertEqual(response["popularity_max"], 2000) def test_min(self): self.insert_rows() response = Band.select(Min(Band.popularity)).first().run_sync() + assert response is not None - self.assertTrue(response["min"] == 10) + self.assertEqual(response["min"], 10) def test_min_alias(self): self.insert_rows() @@ -677,8 +862,9 @@ def test_min_alias(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_min"] == 10) + self.assertEqual(response["popularity_min"], 10) def test_min_as_alias_method(self): self.insert_rows() @@ -688,15 +874,17 @@ def test_min_as_alias_method(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_min"] == 10) + self.assertEqual(response["popularity_min"], 10) def test_sum(self): self.insert_rows() response = Band.select(Sum(Band.popularity)).first().run_sync() + assert response is not None - self.assertTrue(response["sum"] == 3010) + self.assertEqual(response["sum"], 3010) def test_sum_alias(self): self.insert_rows() @@ -706,8 +894,9 @@ def test_sum_alias(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_sum"] == 3010) + self.assertEqual(response["popularity_sum"], 3010) def test_sum_as_alias_method(self): self.insert_rows() @@ -717,8 +906,9 @@ def test_sum_as_alias_method(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_sum"] == 3010) + self.assertEqual(response["popularity_sum"], 3010) def test_sum_with_where_clause(self): self.insert_rows() @@ -729,8 +919,9 @@ def test_sum_with_where_clause(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["sum"] == 3000) + self.assertEqual(response["sum"], 3000) def test_sum_alias_with_where_clause(self): """ @@ -745,8 +936,9 @@ def test_sum_alias_with_where_clause(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(response["popularity_sum"] == 3000) + self.assertEqual(response["popularity_sum"], 3000) def test_sum_as_alias_method_with_where_clause(self): """ @@ -761,8 +953,9 @@ def test_sum_as_alias_method_with_where_clause(self): .first() .run_sync() ) + assert response is not None - 
self.assertTrue(response["popularity_sum"] == 3000) + self.assertEqual(response["popularity_sum"], 3000) def test_chain_different_functions(self): self.insert_rows() @@ -772,9 +965,10 @@ def test_chain_different_functions(self): .first() .run_sync() ) + assert response is not None - self.assertTrue(float(response["avg"]) == 1003.3333333333334) - self.assertTrue(response["sum"] == 3010) + self.assertEqual(float(response["avg"]), 1003.3333333333334) + self.assertEqual(response["sum"], 3010) def test_chain_different_functions_alias(self): self.insert_rows() @@ -787,19 +981,10 @@ def test_chain_different_functions_alias(self): .first() .run_sync() ) + assert response is not None - self.assertTrue( - float(response["popularity_avg"]) == 1003.3333333333334 - ) - self.assertTrue(response["popularity_sum"] == 3010) - - def test_avg_validation(self): - with self.assertRaises(ValueError): - Band.select(Avg(Band.name)).run_sync() - - def test_sum_validation(self): - with self.assertRaises(ValueError): - Band.select(Sum(Band.name)).run_sync() + self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) + self.assertEqual(response["popularity_sum"], 3010) def test_columns(self): """ @@ -815,7 +1000,8 @@ def test_columns(self): .first() .run_sync() ) - self.assertTrue(response == {"name": "Pythonistas"}) + assert response is not None + self.assertDictEqual(response, {"name": "Pythonistas"}) # Multiple calls to 'columns' should be additive. response = ( @@ -826,14 +1012,23 @@ def test_columns(self): .first() .run_sync() ) - self.assertTrue(response == {"id": 1, "name": "Pythonistas"}) + assert response is not None + + if engine_is("cockroach"): + self.assertEqual( + response, {"id": response["id"], "name": "Pythonistas"} + ) + else: + self.assertEqual(response, {"id": 1, "name": "Pythonistas"}) def test_call_chain(self): """ Make sure the call chain lengths are the correct size. """ self.assertEqual(len(Concert.band_1.name._meta.call_chain), 1) - self.assertEqual(len(Concert.band_1.manager.name._meta.call_chain), 2) + self.assertEqual( + len(Concert.band_1._.manager._.name._meta.call_chain), 2 + ) def test_as_alias(self): """ @@ -869,6 +1064,62 @@ def test_as_alias_with_where_clause(self): response, [{"name": "Pythonistas", "manager_name": "Guido"}] ) + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite doesn't have math functions in this version.", + ) + @pytest.mark.skipif( + is_running_cockroach(), + reason=( + "Cockroach raises an error when trying to use the log function." + ), + ) + def test_select_raw(self): + """ + Make sure ``SelectRaw`` can be used in select queries. + """ + self.insert_row() + response = Band.select( + Band.name, SelectRaw("round(log(popularity)) AS popularity_log") + ).run_sync() + self.assertListEqual( + response, [{"name": "Pythonistas", "popularity_log": 3.0}] + ) + + @pytest.mark.skipif( + is_running_sqlite(), + reason="SQLite doesn't support SELECT ... FOR UPDATE.", + ) + def test_lock_rows(self): + """ + Make sure the for_update clause works. 
+        """
+        self.insert_rows()
+
+        query = Band.select()
+        self.assertNotIn("FOR UPDATE", query.__str__())
+
+        query = query.lock_rows()
+        self.assertTrue(query.__str__().endswith("FOR UPDATE"))
+
+        query = query.lock_rows(lock_strength="KEY SHARE")
+        self.assertTrue(query.__str__().endswith("FOR KEY SHARE"))
+
+        query = query.lock_rows(skip_locked=True)
+        self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED"))
+
+        query = query.lock_rows(nowait=True)
+        self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT"))
+
+        query = query.lock_rows(of=(Band,))
+        self.assertTrue(query.__str__().endswith('FOR UPDATE OF "band"'))
+
+        with self.assertRaises(TypeError):
+            query = query.lock_rows(skip_locked=True, nowait=True)
+
+        response = query.run_sync()
+        assert response is not None
+
 
 class TestSelectSecret(TestCase):
     def setUp(self):
@@ -886,4 +1137,373 @@ def test_secret(self):
         user.save().run_sync()
 
         user_dict = BaseUser.select(exclude_secrets=True).first().run_sync()
-        self.assertTrue("password" not in user_dict.keys())
+        assert user_dict is not None
+        self.assertNotIn("password", user_dict.keys())
+
+
+class TestSelectSecretParameter(TestCase):
+    def setUp(self):
+        Venue.create_table().run_sync()
+
+    def tearDown(self):
+        Venue.alter().drop_table().run_sync()
+
+    def test_secret_parameter(self):
+        """
+        Make sure that fields with parameter ``secret=True`` are omitted
+        from the response when requested.
+        """
+        venue = Venue(name="The Garage", capacity=1000)
+        venue.save().run_sync()
+
+        venue_dict = Venue.select(exclude_secrets=True).first().run_sync()
+        assert venue_dict is not None
+        if engine_is("cockroach"):
+            self.assertDictEqual(
+                venue_dict, {"id": venue_dict["id"], "name": "The Garage"}
+            )
+        else:
+            self.assertDictEqual(venue_dict, {"id": 1, "name": "The Garage"})
+        self.assertNotIn("capacity", venue_dict.keys())
+
+
+class TestSelectOrderBy(TestCase):
+    """
+    We use TestCase, rather than DBTestCase, as we want a lot of data to test
+    with.
+    """
+
+    def setUp(self):
+        """
+        Create tables and lots of test data.
+        """
+        create_db_tables_sync(Band, Manager)
+
+        data = [
+            {
+                "band_name": "Pythonistas",
+                "manager_name": "Guido",
+                "popularity": 1000,
+            },
+            {
+                "band_name": "Rustaceans",
+                "manager_name": "Graydon",
+                "popularity": 800,
+            },
+            {
+                "band_name": "C-Sharps",
+                "manager_name": "Anders",
+                "popularity": 800,
+            },
+            {
+                "band_name": "Rubyists",
+                "manager_name": "Matz",
+                "popularity": 820,
+            },
+        ]
+
+        for item in data:
+            manager = (
+                Manager.objects().create(name=item["manager_name"]).run_sync()
+            )
+
+            Band.objects().create(
+                name=item["band_name"],
+                manager=manager,
+                popularity=item["popularity"],
+            ).run_sync()
+
+    def tearDown(self):
+        drop_db_tables_sync(Band, Manager)
+
+    def test_ascending(self):
+        response = Band.select(Band.name).order_by(Band.name).run_sync()
+
+        self.assertEqual(
+            response,
+            [
+                {"name": "C-Sharps"},
+                {"name": "Pythonistas"},
+                {"name": "Rubyists"},
+                {"name": "Rustaceans"},
+            ],
+        )
+
+    def test_descending(self):
+        response = (
+            Band.select(Band.name)
+            .order_by(Band.name, ascending=False)
+            .run_sync()
+        )
+
+        self.assertEqual(
+            response,
+            [
+                {"name": "Rustaceans"},
+                {"name": "Rubyists"},
+                {"name": "Pythonistas"},
+                {"name": "C-Sharps"},
+            ],
+        )
+
+    def test_string(self):
+        """
+        Make sure strings can be used to identify columns if the user prefers.
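+
+        For example, these two queries are equivalent::
+
+            Band.select(Band.name).order_by(Band.name)
+            Band.select(Band.name).order_by("name")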
+        """
+        response = Band.select(Band.name).order_by("name").run_sync()
+
+        self.assertEqual(
+            response,
+            [
+                {"name": "C-Sharps"},
+                {"name": "Pythonistas"},
+                {"name": "Rubyists"},
+                {"name": "Rustaceans"},
+            ],
+        )
+
+    def test_string_unrecognised(self):
+        """
+        Make sure an unrecognised column name raises an Exception.
+        """
+        with self.assertRaises(ValueError) as manager:
+            Band.select(Band.name).order_by("foo")
+
+        self.assertEqual(
+            manager.exception.__str__(),
+            "No matching column found with name == foo",
+        )
+
+    def test_multiple_columns_ascending(self):
+        """
+        Make sure we can order by multiple columns.
+        """
+        response = (
+            Band.select(Band.popularity, Band.name)
+            .order_by(Band.popularity, Band.name)
+            .run_sync()
+        )
+
+        self.assertEqual(
+            response,
+            [
+                {"popularity": 800, "name": "C-Sharps"},
+                {"popularity": 800, "name": "Rustaceans"},
+                {"popularity": 820, "name": "Rubyists"},
+                {"popularity": 1000, "name": "Pythonistas"},
+            ],
+        )
+
+    def test_multiple_columns_descending(self):
+        """
+        Make sure we can order by multiple columns, descending.
+        """
+        response = (
+            Band.select(Band.popularity, Band.name)
+            .order_by(Band.popularity, Band.name, ascending=False)
+            .run_sync()
+        )
+
+        self.assertEqual(
+            response,
+            [
+                {"popularity": 1000, "name": "Pythonistas"},
+                {"popularity": 820, "name": "Rubyists"},
+                {"popularity": 800, "name": "Rustaceans"},
+                {"popularity": 800, "name": "C-Sharps"},
+            ],
+        )
+
+    def test_join(self):
+        """
+        Make sure that we can order using columns in related tables.
+        """
+        response = (
+            Band.select(Band.manager.name.as_alias("manager_name"), Band.name)
+            .order_by(Band.manager.name)
+            .run_sync()
+        )
+        self.assertEqual(
+            response,
+            [
+                {"manager_name": "Anders", "name": "C-Sharps"},
+                {"manager_name": "Graydon", "name": "Rustaceans"},
+                {"manager_name": "Guido", "name": "Pythonistas"},
+                {"manager_name": "Matz", "name": "Rubyists"},
+            ],
+        )
+
+    def test_ascending_descending(self):
+        """
+        Make sure we can combine ascending and descending.
+        """
+        response = (
+            Band.select(Band.popularity, Band.name)
+            .order_by(Band.popularity)
+            .order_by(Band.name, ascending=False)
+            .run_sync()
+        )
+
+        self.assertEqual(
+            response,
+            [
+                {"popularity": 800, "name": "Rustaceans"},
+                {"popularity": 800, "name": "C-Sharps"},
+                {"popularity": 820, "name": "Rubyists"},
+                {"popularity": 1000, "name": "Pythonistas"},
+            ],
+        )
+
+    def test_order_by_raw(self):
+        """
+        Make sure ``OrderByRaw`` can be used to order by anything the user
+        wants.
+        """
+        response = (
+            Band.select(Band.name).order_by(OrderByRaw("name")).run_sync()
+        )
+
+        self.assertEqual(
+            response,
+            [
+                {"name": "C-Sharps"},
+                {"name": "Pythonistas"},
+                {"name": "Rubyists"},
+                {"name": "Rustaceans"},
+            ],
+        )
+
+
+class Album(Table):
+    band = Varchar()
+    title = Varchar()
+    release_date = Date()
+
+
+class TestDistinctOn(TestCase):
+    def setUp(self):
+        Album.create_table().run_sync()
+
+    def tearDown(self):
+        Album.alter().drop_table().run_sync()
+
+    @engines_only("postgres", "cockroach")
+    def test_distinct_on(self):
+        """
+        Make sure the ``distinct`` method can be used to create a
+        ``DISTINCT ON`` clause.
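+
+        For example, to get the most recent album for each band (as
+        asserted below)::
+
+            Album.select(Album.band, Album.title).distinct(
+                on=[Album.band]
+            ).order_by(Album.band).order_by(
+                Album.release_date, ascending=False
+            )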
+        """
+        Album.insert(
+            Album(
+                {
+                    Album.band: "Pythonistas",
+                    Album.title: "P1",
+                    Album.release_date: datetime.date(
+                        year=2022, month=1, day=1
+                    ),
+                }
+            ),
+            Album(
+                {
+                    Album.band: "Pythonistas",
+                    Album.title: "P2",
+                    Album.release_date: datetime.date(
+                        year=2023, month=1, day=1
+                    ),
+                }
+            ),
+            Album(
+                {
+                    Album.band: "Rustaceans",
+                    Album.title: "R1",
+                    Album.release_date: datetime.date(
+                        year=2022, month=1, day=1
+                    ),
+                }
+            ),
+            Album(
+                {
+                    Album.band: "Rustaceans",
+                    Album.title: "R2",
+                    Album.release_date: datetime.date(
+                        year=2023, month=1, day=1
+                    ),
+                }
+            ),
+            Album(
+                {
+                    Album.band: "C-Sharps",
+                    Album.title: "C1",
+                    Album.release_date: datetime.date(
+                        year=2022, month=1, day=1
+                    ),
+                }
+            ),
+            Album(
+                {
+                    Album.band: "C-Sharps",
+                    Album.title: "C2",
+                    Album.release_date: datetime.date(
+                        year=2023, month=1, day=1
+                    ),
+                }
+            ),
+        ).run_sync()
+
+        # Get the most recent album for each band.
+        query = (
+            Album.select(Album.band, Album.title)
+            .distinct(on=[Album.band])
+            .order_by(Album.band)
+            .order_by(Album.release_date, ascending=False)
+        )
+        self.assertIn("DISTINCT ON", query.__str__())
+        response = query.run_sync()
+
+        self.assertEqual(
+            response,
+            [
+                {"band": "C-Sharps", "title": "C2"},
+                {"band": "Pythonistas", "title": "P2"},
+                {"band": "Rustaceans", "title": "R2"},
+            ],
+        )
+
+    @engines_only("sqlite")
+    def test_distinct_on_sqlite(self):
+        """
+        SQLite doesn't support ``DISTINCT ON``, so a ``NotImplementedError``
+        should be raised.
+        """
+        with self.assertRaises(NotImplementedError) as manager:
+            Album.select().distinct(on=[Album.band])
+
+        self.assertEqual(
+            manager.exception.__str__(),
+            "SQLite doesn't support DISTINCT ON",
+        )
+
+    @engines_only("postgres", "cockroach")
+    def test_distinct_on_error(self):
+        """
+        If we pass in something other than a sequence of columns, it should
+        raise a ValueError.
+        """
+        with self.assertRaises(ValueError) as manager:
+            Album.select().distinct(on=Album.band)  # type: ignore
+
+        self.assertEqual(
+            manager.exception.__str__(),
+            "`on` must be a sequence of `Column` instances",
+        )
+
+    @engines_only("postgres", "cockroach")
+    def test_distinct_on_order_by_error(self):
+        """
+        The first column passed to `order_by` must match the first column
+        passed to `on`, otherwise an exception is raised.
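+
+        So this pairing is valid::
+
+            Album.select().distinct(on=[Album.band]).order_by(Album.band)
+
+        But ordering by ``Album.release_date`` first, as below, is not.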
+ """ + with self.assertRaises(DistinctOnError): + Album.select().distinct(on=[Album.band]).order_by( + Album.release_date + ).run_sync() diff --git a/tests/table/test_str.py b/tests/table/test_str.py index c53289c58..9255331de 100644 --- a/tests/table/test_str.py +++ b/tests/table/test_str.py @@ -1,27 +1,46 @@ from unittest import TestCase -from ..example_app.tables import Manager +from piccolo.apps.playground.commands.run import Genre, Manager class TestTableStr(TestCase): - def test_str(self): + def test_all_attributes(self): self.assertEqual( Manager._table_str(), ( "class Manager(Table, tablename='manager'):\n" - " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None)\n" # noqa: E501 - " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None)\n" # noqa: E501 + " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False)\n" # noqa: E501 + " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)\n" # noqa: E501 ), ) + def test_abbreviated(self): self.assertEqual( Manager._table_str(abbreviated=True), ( "class Manager(Table):\n" - " id = Serial()\n" + " id = Serial(primary_key=True)\n" + " name = Varchar(length=50)\n" + ), + ) + + def test_m2m(self): + """ + Make sure M2M relationships appear in the Table string. + """ + + self.assertEqual( + Genre._table_str(abbreviated=True), + ( + "class Genre(Table):\n" + " id = Serial(primary_key=True)\n" " name = Varchar()\n" + " bands = M2M(GenreToBand)\n" ), ) - # We should also be able to print it directly. + def test_print(self): + """ + Make sure we can print it directly without any errors. + """ print(Manager) diff --git a/tests/table/test_table_exists.py b/tests/table/test_table_exists.py index 1ee789e19..6b31afa00 100644 --- a/tests/table/test_table_exists.py +++ b/tests/table/test_table_exists.py @@ -1,15 +1,41 @@ from unittest import TestCase -from ..example_app.tables import Manager +from piccolo.columns import Varchar +from piccolo.schema import SchemaManager +from piccolo.table import Table +from tests.base import engines_skip +from tests.example_apps.music.tables import Manager class TestTableExists(TestCase): def setUp(self): Manager.create_table().run_sync() + def tearDown(self): + Manager.alter().drop_table().run_sync() + def test_table_exists(self): response = Manager.table_exists().run_sync() - self.assertTrue(response is True) + self.assertTrue(response) + + +class Band(Table, schema="schema_1"): + name = Varchar() + + +@engines_skip("sqlite") +class TestTableExistsSchema(TestCase): + def setUp(self): + Band.create_table(auto_create_schema=True).run_sync() def tearDown(self): - Manager.alter().drop_table().run_sync() + SchemaManager().drop_schema( + "schema_1", if_exists=True, cascade=True + ).run_sync() + + def test_table_exists(self): + """ + Make sure it works correctly if the table is in a Postgres schema. 
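+
+        ``Band`` above is declared as::
+
+            class Band(Table, schema="schema_1"):
+                name = Varchar()
+
+        so the check has to look beyond Postgres' default ``public`` schema.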
+ """ + response = Band.table_exists().run_sync() + self.assertTrue(response) diff --git a/tests/table/test_update.py b/tests/table/test_update.py index ba5b33a56..9599cb465 100644 --- a/tests/table/test_update.py +++ b/tests/table/test_update.py @@ -1,5 +1,30 @@ -from ..base import DBTestCase -from ..example_app.tables import Band, Poster +import dataclasses +import datetime +from typing import Any +from unittest import TestCase + +import pytest + +from piccolo.columns.base import Column +from piccolo.columns.column_types import ( + Date, + Integer, + Interval, + Text, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.querystring import QueryString +from piccolo.table import Table +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_skip, + is_running_sqlite, + sqlite_only, +) +from tests.example_apps.music.tables import Band, Manager class TestUpdate(DBTestCase): @@ -88,156 +113,628 @@ def test_update_values_with_kwargs(self): self.check_response() + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite version not supported", + ) + def test_update_returning(self): + """ + Make sure update works with the `returning` clause. + """ + self.insert_rows() -class TestIntUpdateOperators(DBTestCase): - def test_add(self): - self.insert_row() - - Band.update({Band.popularity: Band.popularity + 10}).run_sync() - - response = Band.select(Band.popularity).first().run_sync() - - self.assertEqual(response["popularity"], 1010) - - def test_add_column(self): - self.insert_row() - - Band.update( - {Band.popularity: Band.popularity + Band.popularity} - ).run_sync() - - response = Band.select(Band.popularity).first().run_sync() - - self.assertEqual(response["popularity"], 2000) - - def test_radd(self): - self.insert_row() - - Band.update({Band.popularity: 10 + Band.popularity}).run_sync() - - response = Band.select(Band.popularity).first().run_sync() - - self.assertEqual(response["popularity"], 1010) - - def test_sub(self): - self.insert_row() - - Band.update({Band.popularity: Band.popularity - 10}).run_sync() + response = ( + Band.update({Band.name: "Pythonistas 2"}) + .where(Band.name == "Pythonistas") + .returning(Band.name) + .run_sync() + ) - response = Band.select(Band.popularity).first().run_sync() + self.assertEqual(response, [{"name": "Pythonistas 2"}]) - self.assertEqual(response["popularity"], 990) + @pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.35), + reason="SQLite version not supported", + ) + def test_update_returning_alias(self): + """ + Make sure update works with the `returning` clause. 
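+
+        Here the returned column is aliased, e.g.::
+
+            Band.update({Band.name: "Pythonistas 2"}).where(
+                Band.name == "Pythonistas"
+            ).returning(Band.name.as_alias("band name"))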
+ """ + self.insert_rows() - def test_rsub(self): - self.insert_row() + response = ( + Band.update({Band.name: "Pythonistas 2"}) + .where(Band.name == "Pythonistas") + .returning(Band.name.as_alias("band name")) + .run_sync() + ) - Band.update({Band.popularity: 1100 - Band.popularity}).run_sync() + self.assertEqual(response, [{"band name": "Pythonistas 2"}]) + + +############################################################################### +# Test operators + + +class MyTable(Table): + integer = Integer(null=True) + other_integer = Integer(null=True, default=5) + timestamp = Timestamp(null=True) + timestamptz = Timestamptz(null=True) + date = Date(null=True) + interval = Interval(null=True) + varchar = Varchar(null=True) + text = Text(null=True) + + +INITIAL_DATETIME = datetime.datetime( + year=2022, month=1, day=1, hour=21, minute=0 +) +INITIAL_INTERVAL = datetime.timedelta(days=1, hours=1, minutes=1) + +DATETIME_DELTA = datetime.timedelta( + days=1, hours=1, minutes=1, seconds=30, microseconds=1000 +) +DATE_DELTA = datetime.timedelta(days=1) + + +@dataclasses.dataclass +class OperatorTestCase: + description: str + column: Column + initial: Any + querystring: QueryString + expected: Any + + +TEST_CASES = [ + # Text + OperatorTestCase( + description="Add Text", + column=MyTable.text, + initial="Pythonistas", + querystring=MyTable.text + "!!!", + expected="Pythonistas!!!", + ), + OperatorTestCase( + description="Add Text columns", + column=MyTable.text, + initial="Pythonistas", + querystring=MyTable.text + MyTable.text, + expected="PythonistasPythonistas", + ), + OperatorTestCase( + description="Reverse add Text", + column=MyTable.text, + initial="Pythonistas", + querystring="!!!" + MyTable.text, + expected="!!!Pythonistas", + ), + OperatorTestCase( + description="Text is null", + column=MyTable.text, + initial=None, + querystring=MyTable.text + "!!!", + expected=None, + ), + OperatorTestCase( + description="Reverse Text is null", + column=MyTable.text, + initial=None, + querystring="!!!" + MyTable.text, + expected=None, + ), + # Varchar + OperatorTestCase( + description="Add Varchar", + column=MyTable.varchar, + initial="Pythonistas", + querystring=MyTable.varchar + "!!!", + expected="Pythonistas!!!", + ), + OperatorTestCase( + description="Add Varchar columns", + column=MyTable.varchar, + initial="Pythonistas", + querystring=MyTable.varchar + MyTable.varchar, + expected="PythonistasPythonistas", + ), + OperatorTestCase( + description="Reverse add Varchar", + column=MyTable.varchar, + initial="Pythonistas", + querystring="!!!" + MyTable.varchar, + expected="!!!Pythonistas", + ), + OperatorTestCase( + description="Varchar is null", + column=MyTable.varchar, + initial=None, + querystring=MyTable.varchar + "!!!", + expected=None, + ), + OperatorTestCase( + description="Reverse Varchar is null", + column=MyTable.varchar, + initial=None, + querystring="!!!" 
+ MyTable.varchar,
+        expected=None,
+    ),
+    # Integer
+    OperatorTestCase(
+        description="Add Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer + 10,
+        expected=1010,
+    ),
+    OperatorTestCase(
+        description="Reverse add Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=10 + MyTable.integer,
+        expected=1010,
+    ),
+    OperatorTestCase(
+        description="Add Integer columns together",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer + MyTable.integer,
+        expected=2000,
+    ),
+    OperatorTestCase(
+        description="Subtract Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer - 10,
+        expected=990,
+    ),
+    OperatorTestCase(
+        description="Reverse subtract Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=2000 - MyTable.integer,
+        expected=1000,
+    ),
+    OperatorTestCase(
+        description="Subtract Integer Columns",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer - MyTable.other_integer,
+        expected=995,
+    ),
+    OperatorTestCase(
+        description="Add Integer Columns",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer + MyTable.other_integer,
+        expected=1005,
+    ),
+    OperatorTestCase(
+        description="Multiply Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer * 2,
+        expected=2000,
+    ),
+    OperatorTestCase(
+        description="Reverse multiply Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=2 * MyTable.integer,
+        expected=2000,
+    ),
+    OperatorTestCase(
+        description="Divide Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=MyTable.integer / 10,
+        expected=100,
+    ),
+    OperatorTestCase(
+        description="Reverse divide Integer",
+        column=MyTable.integer,
+        initial=1000,
+        querystring=2000 / MyTable.integer,
+        expected=2,
+    ),
+    OperatorTestCase(
+        description="Integer is null",
+        column=MyTable.integer,
+        initial=None,
+        querystring=MyTable.integer + 1,
+        expected=None,
+    ),
+    OperatorTestCase(
+        description="Reverse Integer is null",
+        column=MyTable.integer,
+        initial=None,
+        querystring=1 + MyTable.integer,
+        expected=None,
+    ),
+    # Timestamp
+    OperatorTestCase(
+        description="Add Timestamp",
+        column=MyTable.timestamp,
+        initial=INITIAL_DATETIME,
+        querystring=MyTable.timestamp + DATETIME_DELTA,
+        expected=datetime.datetime(
+            year=2022,
+            month=1,
+            day=2,
+            hour=22,
+            minute=1,
+            second=30,
+            microsecond=1000,
+        ),
+    ),
+    OperatorTestCase(
+        description="Reverse add Timestamp",
+        column=MyTable.timestamp,
+        initial=INITIAL_DATETIME,
+        querystring=DATETIME_DELTA + MyTable.timestamp,
+        expected=datetime.datetime(
+            year=2022,
+            month=1,
+            day=2,
+            hour=22,
+            minute=1,
+            second=30,
+            microsecond=1000,
+        ),
+    ),
+    OperatorTestCase(
+        description="Subtract Timestamp",
+        column=MyTable.timestamp,
+        initial=INITIAL_DATETIME,
+        querystring=MyTable.timestamp - DATETIME_DELTA,
+        expected=datetime.datetime(
+            year=2021,
+            month=12,
+            day=31,
+            hour=19,
+            minute=58,
+            second=29,
+            microsecond=999000,
+        ),
+    ),
+    OperatorTestCase(
+        description="Timestamp is null",
+        column=MyTable.timestamp,
+        initial=None,
+        querystring=MyTable.timestamp + DATETIME_DELTA,
+        expected=None,
+    ),
+    # Timestamptz
+    OperatorTestCase(
+        description="Add Timestamptz",
+        column=MyTable.timestamptz,
+        initial=INITIAL_DATETIME,
+        querystring=MyTable.timestamptz + DATETIME_DELTA,
+        expected=datetime.datetime(
+            year=2022,
+            month=1,
+            day=2,
+            hour=22,
+            minute=1,
+            second=30,
+            microsecond=1000,
+            tzinfo=datetime.timezone.utc,
+        ),
+    ),
+
OperatorTestCase( + description="Reverse add Timestamptz", + column=MyTable.timestamptz, + initial=INITIAL_DATETIME, + querystring=DATETIME_DELTA + MyTable.timestamptz, + expected=datetime.datetime( + year=2022, + month=1, + day=2, + hour=22, + minute=1, + second=30, + microsecond=1000, + tzinfo=datetime.timezone.utc, + ), + ), + OperatorTestCase( + description="Subtract Timestamptz", + column=MyTable.timestamptz, + initial=INITIAL_DATETIME, + querystring=MyTable.timestamptz - DATETIME_DELTA, + expected=datetime.datetime( + year=2021, + month=12, + day=31, + hour=19, + minute=58, + second=29, + microsecond=999000, + tzinfo=datetime.timezone.utc, + ), + ), + OperatorTestCase( + description="Timestamptz is null", + column=MyTable.timestamptz, + initial=None, + querystring=MyTable.timestamptz + DATETIME_DELTA, + expected=None, + ), + # Date + OperatorTestCase( + description="Add Date", + column=MyTable.date, + initial=INITIAL_DATETIME, + querystring=MyTable.date + DATE_DELTA, + expected=datetime.date(year=2022, month=1, day=2), + ), + OperatorTestCase( + description="Reverse add Date", + column=MyTable.date, + initial=INITIAL_DATETIME, + querystring=DATE_DELTA + MyTable.date, + expected=datetime.date(year=2022, month=1, day=2), + ), + OperatorTestCase( + description="Subtract Date", + column=MyTable.date, + initial=INITIAL_DATETIME, + querystring=MyTable.date - DATE_DELTA, + expected=datetime.date(year=2021, month=12, day=31), + ), + OperatorTestCase( + description="Date is null", + column=MyTable.date, + initial=None, + querystring=MyTable.date + DATE_DELTA, + expected=None, + ), + # Interval + OperatorTestCase( + description="Add Interval", + column=MyTable.interval, + initial=INITIAL_INTERVAL, + querystring=MyTable.interval + DATETIME_DELTA, + expected=datetime.timedelta(days=2, seconds=7350, microseconds=1000), + ), + OperatorTestCase( + description="Reverse add Interval", + column=MyTable.interval, + initial=INITIAL_INTERVAL, + querystring=DATETIME_DELTA + MyTable.interval, + expected=datetime.timedelta(days=2, seconds=7350, microseconds=1000), + ), + OperatorTestCase( + description="Subtract Interval", + column=MyTable.interval, + initial=INITIAL_INTERVAL, + querystring=MyTable.interval - DATETIME_DELTA, + expected=datetime.timedelta( + days=-1, seconds=86369, microseconds=999000 + ), + ), + OperatorTestCase( + description="Interval is null", + column=MyTable.interval, + initial=None, + querystring=MyTable.interval + DATETIME_DELTA, + expected=None, + ), +] + + +class TestOperators(TestCase): + def setUp(self): + MyTable.create_table().run_sync() - response = Band.select(Band.popularity).first().run_sync() + def tearDown(self): + MyTable.alter().drop_table().run_sync() - self.assertEqual(response["popularity"], 100) + @engines_skip("cockroach") + def test_operators(self): + for test_case in TEST_CASES: + print(test_case.description) - def test_mul(self): - self.insert_row() + # Create the initial data in the database. + instance = MyTable() + setattr(instance, test_case.column._meta.name, test_case.initial) + instance.save().run_sync() - Band.update({Band.popularity: Band.popularity * 2}).run_sync() + # Apply the update. + MyTable.update( + {test_case.column: test_case.querystring}, force=True + ).run_sync() - response = Band.select(Band.popularity).first().run_sync() + # Make sure the value returned from the database is correct. 
+            new_value = getattr(
+                MyTable.objects().first().run_sync(),
+                test_case.column._meta.name,
+            )
-        self.assertEqual(response["popularity"], 2000)
+            self.assertEqual(
+                new_value, test_case.expected, msg=test_case.description
+            )
-    def test_rmul(self):
-        self.insert_row()
+            # Clean up
+            MyTable.delete(force=True).run_sync()
-        Band.update({Band.popularity: 2 * Band.popularity}).run_sync()
+    @sqlite_only
+    def test_edge_cases(self):
+        """
+        Some use cases aren't supported by SQLite, and should raise a
+        ``ValueError``.
+        """
+        with self.assertRaises(ValueError):
+            # An error should be raised because we can't save at this level
+            # of resolution - 1 millisecond is the minimum.
+            MyTable.timestamp + datetime.timedelta(  # type: ignore
+                microseconds=1
+            )
-        response = Band.select(Band.popularity).first().run_sync()
-        self.assertEqual(response["popularity"], 2000)
+###############################################################################
+# Test auto_update
-    def test_div(self):
-        self.insert_row()
-        Band.update({Band.popularity: Band.popularity / 10}).run_sync()
+class AutoUpdateTable(Table, tablename="my_table"):
+    name = Varchar()
+    modified_on = Timestamp(
+        auto_update=datetime.datetime.now, null=True, default=None
+    )
-        response = Band.select(Band.popularity).first().run_sync()
-        self.assertEqual(response["popularity"], 100)
+class TestAutoUpdate(TestCase):
+    def setUp(self):
+        AutoUpdateTable.create_table().run_sync()
-    def test_rdiv(self):
-        self.insert_row()
+    def tearDown(self):
+        AutoUpdateTable.alter().drop_table().run_sync()
-        Band.update({Band.popularity: 1000 / Band.popularity}).run_sync()
+    def test_save(self):
+        """
+        Make sure the ``save`` method uses ``auto_update`` columns correctly.
+        """
+        row = AutoUpdateTable(name="test")
-        response = Band.select(Band.popularity).first().run_sync()
+        # Saving for the first time is an INSERT, so `auto_update` shouldn't
+        # be triggered.
+        row.save().run_sync()
+        self.assertIsNone(row.modified_on)
-        self.assertEqual(response["popularity"], 1)
+        # A subsequent save is an UPDATE, so `auto_update` should be triggered.
+        row.name = "test 2"
+        row.save().run_sync()
+        self.assertIsInstance(row.modified_on, datetime.datetime)
+        # If we save it again, `auto_update` should be applied again.
+        existing_modified_on = row.modified_on
+        row.name = "test 3"
+        row.save().run_sync()
+        self.assertIsInstance(row.modified_on, datetime.datetime)
+        self.assertGreater(row.modified_on, existing_modified_on)
-class TestVarcharUpdateOperators(DBTestCase):
-    def test_add(self):
-        self.insert_row()
+    def test_update(self):
+        """
+        Make sure the update method uses ``auto_update`` columns correctly.
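+
+        ``modified_on`` is declared above with
+        ``auto_update=datetime.datetime.now``, so a bulk update should set
+        it without it being passed in explicitly::
+
+            AutoUpdateTable.update(
+                {AutoUpdateTable.name: "test 2"}, force=True
+            )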
+ """ + # Insert a row for us to update + AutoUpdateTable.insert(AutoUpdateTable(name="test")).run_sync() + + data = ( + AutoUpdateTable.select( + AutoUpdateTable.name, AutoUpdateTable.modified_on + ) + .first() + .run_sync() + ) - Band.update({Band.name: Band.name + "!!!"}).run_sync() + assert data is not None - response = Band.select(Band.name).first().run_sync() + self.assertDictEqual( + data, + {"name": "test", "modified_on": None}, + ) - self.assertEqual(response["name"], "Pythonistas!!!") + # Update the row + AutoUpdateTable.update( + {AutoUpdateTable.name: "test 2"}, force=True + ).run_sync() - def test_add_column(self): - self.insert_row() + # Retrieve the row + updated_row = ( + AutoUpdateTable.select( + AutoUpdateTable.name, AutoUpdateTable.modified_on + ) + .first() + .run_sync() + ) + assert updated_row is not None + self.assertIsInstance(updated_row["modified_on"], datetime.datetime) + self.assertEqual(updated_row["name"], "test 2") - Band.update({Band.name: Band.name + Band.name}).run_sync() - response = Band.select(Band.name).first().run_sync() +############################################################################### +# Test update with joins - self.assertEqual(response["name"], "PythonistasPythonistas") - def test_radd(self): - self.insert_row() +class TestUpdateWithJoin(DBTestCase): + def test_join(self): + """ + Make sure updates work when the where clause needs a join. + """ + self.insert_rows() + Band.update({Band.name: "New name"}).where( + Band.manager.name == "Guido" + ).run_sync() - Band.update({Band.name: "!!!" + Band.name}).run_sync() + self.assertEqual( + Band.count().where(Band.name == "New name").run_sync(), 1 + ) - response = Band.select(Band.name).first().run_sync() + def test_multiple_matches(self): + """ + Make sure it works when the join has multiple matching values. + """ + self.insert_rows() - self.assertEqual(response["name"], "!!!Pythonistas") + # Create an additional band with the same manager. + manager = Manager.objects().get(Manager.name == "Guido").run_sync() + band = Band(name="Pythonistas 2", manager=manager) + band.save().run_sync() + Band.update({Band.name: "New name"}).where( + Band.manager.name == "Guido" + ).run_sync() -class TestTextUpdateOperators(DBTestCase): - def setUp(self): - super().setUp() - Poster(content="Join us for this amazing show").save().run_sync() + self.assertEqual( + Band.count().where(Band.name == "New name").run_sync(), 2 + ) - def test_add(self): - Poster.update({Poster.content: Poster.content + "!!!"}).run_sync() + def test_no_matches(self): + """ + Make sure it works when the join has no matching values. + """ + self.insert_rows() - response = Poster.select(Poster.content).first().run_sync() + Band.update({Band.name: "New name"}).where( + Band.manager.name == "Mr Manager" + ).run_sync() self.assertEqual( - response["content"], "Join us for this amazing show!!!" + Band.count().where(Band.name == "New name").run_sync(), 0 ) - def test_add_column(self): - self.insert_row() + def test_and(self): + """ + Make sure it works when combined with other where clauses using AND. + """ + self.insert_rows() - Poster.update( - {Poster.content: Poster.content + Poster.content} - ).run_sync() + # Create an additional band with the same manager, and different + # popularity. 
+ manager = Manager.objects().get(Manager.name == "Guido").run_sync() + band = Band(name="Pythonistas 2", manager=manager, popularity=10000) + band.save().run_sync() - response = Poster.select(Poster.content).first().run_sync() + Band.update({Band.name: "New name"}).where( + Band.manager.name == "Guido", Band.popularity == 10000 + ).run_sync() self.assertEqual( - response["content"], - "Join us for this amazing show" * 2, + Band.count().where(Band.name == "New name").run_sync(), 1 ) - def test_radd(self): - self.insert_row() - - Poster.update({Poster.content: "!!!" + Poster.content}).run_sync() + def test_or(self): + """ + Make sure it works when combined with other where clauses using OR. + """ + self.insert_rows() - response = Poster.select(Poster.content).first().run_sync() + Band.update({Band.name: "New name"}).where( + (Band.manager.name == "Guido") | (Band.manager.name == "Graydon") + ).run_sync() self.assertEqual( - response["content"], "!!!Join us for this amazing show" + Band.count().where(Band.name == "New name").run_sync(), 2 ) diff --git a/tests/table/test_update_self.py b/tests/table/test_update_self.py new file mode 100644 index 000000000..c06afe708 --- /dev/null +++ b/tests/table/test_update_self.py @@ -0,0 +1,27 @@ +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Band, Manager + + +class TestUpdateSelf(AsyncTableTest): + + tables = [Band, Manager] + + async def test_update_self(self): + band = Band({Band.name: "Pythonistas", Band.popularity: 1000}) + + # Make sure we get a ValueError if it's not in the database yet. + with self.assertRaises(ValueError): + await band.update_self({Band.popularity: Band.popularity + 1}) + + # Save it, so it's in the database + await band.save() + + # Make sure we can successfully update the object + await band.update_self({Band.popularity: Band.popularity + 1}) + + # Make sure the value was updated on the object + assert band.popularity == 1001 + + # Make sure the value was updated in the database + await band.refresh() + assert band.popularity == 1001 diff --git a/tests/test_main.py b/tests/test_main.py index 0a0367532..745507f69 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,9 +1,14 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase +from piccolo.apps.migrations.tables import Migration from piccolo.main import main -class TestMain(TestCase): +class TestMain(IsolatedAsyncioTestCase): + + async def asyncTearDown(self): + await Migration.alter().drop_table(if_exists=True) + def test_main(self): # Just make sure it runs without raising any errors. main() diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 000000000..d8ec3d481 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,183 @@ +from unittest import TestCase + +from piccolo.schema import SchemaManager +from piccolo.table import Table +from tests.base import engines_skip + + +class Band(Table, schema="schema_1"): + pass + + +@engines_skip("sqlite") +class TestListTables(TestCase): + def setUp(self): + Band.create_table().run_sync() + + def tearDown(self): + Band.alter().drop_table().run_sync() + + def test_list_tables(self): + """ + Make sure we can list all the tables in a schema. 
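+
+        A sketch of the call being tested::
+
+            SchemaManager().list_tables(schema_name="schema_1").run_sync()
+            # >>> ['band']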
+ """ + schema_name = Band._meta.schema + + assert schema_name is not None + table_list = ( + SchemaManager().list_tables(schema_name=schema_name).run_sync() + ) + self.assertListEqual(table_list, [Band._meta.tablename]) + + +@engines_skip("sqlite") +class TestCreateAndDrop(TestCase): + def test_create_and_drop(self): + """ + Make sure a schema can be created, and dropped. + """ + manager = SchemaManager() + + # Make sure schema names with spaces, and clashing with keywords work. + for schema_name in ("test_schema", "test schema", "user"): + manager.create_schema(schema_name=schema_name).run_sync() + + self.assertIn(schema_name, manager.list_schemas().run_sync()) + + manager.drop_schema(schema_name=schema_name).run_sync() + self.assertNotIn(schema_name, manager.list_schemas().run_sync()) + + +@engines_skip("sqlite") +class TestMoveTable(TestCase): + new_schema = "schema_2" + + def setUp(self): + Band.create_table(if_not_exists=True).run_sync() + SchemaManager().create_schema( + self.new_schema, if_not_exists=True + ).run_sync() + + def tearDown(self): + Band.alter().drop_table(if_exists=True).run_sync() + SchemaManager().drop_schema( + self.new_schema, if_exists=True, cascade=True + ).run_sync() + + def test_move_table(self): + """ + Make sure we can move a table to a different schema. + """ + manager = SchemaManager() + + manager.move_table( + table_name=Band._meta.tablename, + new_schema=self.new_schema, + current_schema=Band._meta.schema, + ).run_sync() + + self.assertIn( + Band._meta.tablename, + manager.list_tables(schema_name=self.new_schema).run_sync(), + ) + + self.assertNotIn( + Band._meta.tablename, + manager.list_tables(schema_name="schema_1").run_sync(), + ) + + +@engines_skip("sqlite") +class TestRenameSchema(TestCase): + manager = SchemaManager() + schema_name = "test_schema" + new_schema_name = "test_schema_2" + + def tearDown(self): + for schema_name in (self.schema_name, self.new_schema_name): + self.manager.drop_schema( + schema_name=schema_name, if_exists=True + ).run_sync() + + def test_rename_schema(self): + """ + Make sure we can rename a schema. 
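+
+        A sketch of the call being tested::
+
+            SchemaManager().rename_schema(
+                schema_name="test_schema",
+                new_schema_name="test_schema_2",
+            ).run_sync()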
+ """ + self.manager.create_schema( + schema_name=self.schema_name, if_not_exists=True + ).run_sync() + + self.manager.rename_schema( + schema_name=self.schema_name, new_schema_name=self.new_schema_name + ).run_sync() + + self.assertIn( + self.new_schema_name, self.manager.list_schemas().run_sync() + ) + + +@engines_skip("sqlite") +class TestDDL(TestCase): + manager = SchemaManager() + + def test_create_schema(self): + self.assertEqual( + self.manager.create_schema( + schema_name="schema_1", if_not_exists=False + ).ddl, + 'CREATE SCHEMA "schema_1"', + ) + + self.assertEqual( + self.manager.create_schema( + schema_name="schema_1", if_not_exists=True + ).ddl, + 'CREATE SCHEMA IF NOT EXISTS "schema_1"', + ) + + def test_drop_schema(self): + self.assertEqual( + self.manager.drop_schema( + schema_name="schema_1", if_exists=False + ).ddl, + 'DROP SCHEMA "schema_1"', + ) + + self.assertEqual( + self.manager.drop_schema( + schema_name="schema_1", if_exists=True + ).ddl, + 'DROP SCHEMA IF EXISTS "schema_1"', + ) + + self.assertEqual( + self.manager.drop_schema( + schema_name="schema_1", if_exists=True, cascade=True + ).ddl, + 'DROP SCHEMA IF EXISTS "schema_1" CASCADE', + ) + + self.assertEqual( + self.manager.drop_schema( + schema_name="schema_1", if_exists=False, cascade=True + ).ddl, + 'DROP SCHEMA "schema_1" CASCADE', + ) + + def test_move_table(self): + self.assertEqual( + self.manager.move_table( + table_name="band", + new_schema="schema_2", + current_schema="schema_1", + ).ddl, + 'ALTER TABLE "schema_1"."band" SET SCHEMA "schema_2"', + ) + + self.assertEqual( + self.manager.move_table( + table_name="band", + new_schema="schema_2", + ).ddl, + 'ALTER TABLE "band" SET SCHEMA "schema_2"', + ) diff --git a/tests/testing/__init__.py b/tests/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py new file mode 100644 index 000000000..b1d07376a --- /dev/null +++ b/tests/testing/test_model_builder.py @@ -0,0 +1,240 @@ +import asyncio +import enum +import json +import unittest + +from piccolo.columns import ( + Array, + Decimal, + ForeignKey, + Integer, + LazyTableReference, + Numeric, + Real, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from piccolo.testing.model_builder import ModelBuilder +from tests.base import engines_skip +from tests.example_apps.music.tables import ( + Band, + Concert, + Manager, + Poster, + RecordingStudio, + Shirt, + Ticket, + Venue, +) + + +class TableWithArrayField(Table): + class Choices(enum.Enum): + a = "a" + b = "b" + + strings = Array(Varchar(30)) + integers = Array(Integer()) + floats = Array(Real()) + choices = Array(Varchar(), choices=Choices) + + +class TableWithDecimal(Table): + numeric = Numeric() + numeric_with_digits = Numeric(digits=(4, 2)) + decimal = Decimal() + decimal_with_digits = Decimal(digits=(4, 2)) + + +class BandWithLazyReference(Table): + manager: ForeignKey["Manager"] = ForeignKey( + references=LazyTableReference( + "Manager", module_path="tests.example_apps.music.tables" + ) + ) + + +class BandWithRecursiveReference(Table): + manager: ForeignKey["Manager"] = ForeignKey("self") + + +TABLES = ( + Manager, + Band, + Poster, + RecordingStudio, + Shirt, + Venue, + Concert, + Ticket, + TableWithArrayField, + TableWithDecimal, + BandWithLazyReference, + BandWithRecursiveReference, +) + + +# Cockroach Bug: Can turn ON when resolved: 
https://github.com/cockroachdb/cockroach/issues/71908 # noqa: E501 +@engines_skip("cockroach") +class TestModelBuilder(unittest.TestCase): + @classmethod + def setUpClass(cls): + create_db_tables_sync(*TABLES) + + @classmethod + def tearDownClass(cls) -> None: + drop_db_tables_sync(*TABLES) + + def test_async(self): + async def build_model(table_class: type[Table]): + return await ModelBuilder.build(table_class) + + for table_class in TABLES: + asyncio.run(build_model(table_class)) + + def test_sync(self): + for table_class in TABLES: + ModelBuilder.build_sync(table_class) + + def test_choices(self): + shirt = ModelBuilder.build_sync(Shirt) + queried_shirt = ( + Shirt.objects().where(Shirt.id == shirt.id).first().run_sync() + ) + assert queried_shirt is not None + + self.assertIn( + queried_shirt.size, + ["s", "l", "m"], + ) + + def test_array_choices(self): + """ + Make sure that ``ModelBuilder`` generates arrays where each array + element is a valid choice. + """ + instance = ModelBuilder.build_sync(TableWithArrayField) + for value in instance.choices: + # Will raise an exception if the enum value isn't found: + TableWithArrayField.Choices[value] + + def test_datetime(self): + """ + Make sure that ``ModelBuilder`` generates timezone aware datetime + objects for ``Timestamptz`` columns, and timezone naive datetime + objects for ``Timestamp`` columns. + """ + + class Table1(Table): + starts = Timestamptz() + + class Table2(Table): + starts = Timestamp() + + model_1 = ModelBuilder.build_sync(Table1, persist=False) + assert model_1.starts.tzinfo is not None + + model_2 = ModelBuilder.build_sync(Table2, persist=False) + assert model_2.starts.tzinfo is None + + def test_foreign_key(self): + model = ModelBuilder.build_sync(Band, persist=True) + + self.assertTrue( + Manager.exists().where(Manager.id == model.manager).run_sync() + ) + + def test_lazy_foreign_key(self): + model = ModelBuilder.build_sync(BandWithLazyReference, persist=True) + + self.assertTrue( + Manager.exists().where(Manager.id == model.manager).run_sync() + ) + + def test_recursive_foreign_key(self): + """ + Make sure no infinite loops are created with recursive foreign keys. + """ + model = ModelBuilder.build_sync( + BandWithRecursiveReference, persist=True + ) + # It should be set to None, as this foreign key is nullable. 
+ self.assertIsNone(model.manager) + + def test_invalid_column(self): + with self.assertRaises(ValueError): + ModelBuilder.build_sync(Band, defaults={"X": 1}) + + def test_minimal(self): + band = ModelBuilder.build_sync(Band, minimal=True) + + self.assertTrue(Band.exists().where(Band.id == band.id).run_sync()) + + def test_persist_false(self): + band = ModelBuilder.build_sync(Band, persist=False) + + self.assertFalse(Band.exists().where(Band.id == band.id).run_sync()) + + def test_valid_column(self): + manager = ModelBuilder.build_sync( + Manager, defaults={Manager.name: "Guido"} + ) + + queried_manager = ( + Manager.objects() + .where(Manager.id == manager.id) + .first() + .run_sync() + ) + assert queried_manager is not None + + self.assertEqual(queried_manager.name, "Guido") + + def test_valid_column_string(self): + manager = ModelBuilder.build_sync(Manager, defaults={"name": "Guido"}) + + queried_manager = ( + Manager.objects() + .where(Manager.id == manager.id) + .first() + .run_sync() + ) + assert queried_manager is not None + + self.assertEqual(queried_manager.name, "Guido") + + def test_valid_foreign_key(self): + manager = ModelBuilder.build_sync(Manager) + + band = ModelBuilder.build_sync(Band, defaults={Band.manager: manager}) + + self.assertEqual(manager._meta.primary_key, band.manager) + + def test_valid_foreign_key_string(self): + manager = ModelBuilder.build_sync(Manager) + + band = ModelBuilder.build_sync(Band, defaults={"manager": manager}) + + self.assertEqual(manager._meta.primary_key, band.manager) + + def test_json(self): + """ + Make sure the generated JSON can be parsed. + + This is important, because we might have queries like this:: + + >>> await RecordingStudio.select().output(load_json=True) + + """ + studio = ModelBuilder.build_sync(RecordingStudio) + self.assertIsInstance(json.loads(studio.facilities), dict) + self.assertIsInstance(json.loads(studio.facilities_b), dict) + + for facilities in ( + RecordingStudio.select(RecordingStudio.facilities) + .output(load_json=True, as_list=True) + .run_sync() + ): + self.assertIsInstance(facilities, dict) diff --git a/tests/testing/test_random_builder.py b/tests/testing/test_random_builder.py new file mode 100644 index 000000000..da406dc3b --- /dev/null +++ b/tests/testing/test_random_builder.py @@ -0,0 +1,59 @@ +import decimal +import unittest +from enum import Enum + +from piccolo.testing.random_builder import RandomBuilder + + +class TestRandomBuilder(unittest.TestCase): + def test_next_bool(self): + random_bool = RandomBuilder.next_bool() + self.assertIn(random_bool, [True, False]) + + def test_next_bytes(self): + random_bytes = RandomBuilder.next_bytes(length=100) + self.assertEqual(len(random_bytes), 100) + + def test_next_date(self): + random_date = RandomBuilder.next_date() + self.assertGreaterEqual(random_date.year, 2000) + self.assertLessEqual(random_date.year, 2050) + + def test_next_datetime(self): + random_datetime = RandomBuilder.next_datetime() + self.assertGreaterEqual(random_datetime.year, 2000) + self.assertLessEqual(random_datetime.year, 2050) + + def test_next_enum(self): + class Color(Enum): + RED = 1 + BLUE = 2 + + random_enum = RandomBuilder.next_enum(Color) + self.assertIsInstance(random_enum, int) + + def test_next_float(self): + random_float = RandomBuilder.next_float(maximum=1000) + self.assertLessEqual(random_float, 1000) + + def test_next_decimal(self): + random_decimal = RandomBuilder.next_decimal(precision=4, scale=2) + self.assertLessEqual(random_decimal, decimal.Decimal("99.99")) + + def 
test_next_int(self): + random_int = RandomBuilder.next_int() + self.assertLessEqual(random_int, 2147483647) + + def test_next_str(self): + random_str = RandomBuilder.next_str(length=64) + self.assertLessEqual(len(random_str), 64) + + def test_next_time(self): + RandomBuilder.next_time() + + def test_next_timedelta(self): + random_timedelta = RandomBuilder.next_timedelta() + self.assertLessEqual(random_timedelta.days, 7) + + def test_next_uuid(self): + RandomBuilder.next_uuid() diff --git a/tests/testing/test_test_case.py b/tests/testing/test_test_case.py new file mode 100644 index 000000000..963a3c371 --- /dev/null +++ b/tests/testing/test_test_case.py @@ -0,0 +1,65 @@ +import sys + +import pytest + +from piccolo.engine import engine_finder +from piccolo.testing.test_case import ( + AsyncTableTest, + AsyncTransactionTest, + TableTest, +) +from tests.example_apps.music.tables import Band, Manager + + +class TestTableTest(TableTest): + """ + Make sure the tables are created automatically. + """ + + tables = [Band, Manager] + + async def test_tables_created(self): + self.assertTrue(Band.table_exists().run_sync()) + self.assertTrue(Manager.table_exists().run_sync()) + + +class TestAsyncTableTest(AsyncTableTest): + """ + Make sure the tables are created automatically in async tests. + """ + + tables = [Band, Manager] + + async def test_tables_created(self): + self.assertTrue(await Band.table_exists()) + self.assertTrue(await Manager.table_exists()) + + +@pytest.mark.skipif(sys.version_info <= (3, 11), reason="Python 3.11 required") +class TestAsyncTransaction(AsyncTransactionTest): + """ + Make sure that the test exists within a transaction. + """ + + async def test_transaction_exists(self): + db = engine_finder() + assert db is not None + self.assertTrue(db.transaction_exists()) + + +@pytest.mark.skipif(sys.version_info <= (3, 11), reason="Python 3.11 required") +class TestAsyncTransactionRolledBack(AsyncTransactionTest): + """ + Make sure that the changes get rolled back automatically. + """ + + async def asyncTearDown(self): + await super().asyncTearDown() + + assert Manager.table_exists().run_sync() is False + + async def test_insert_data(self): + await Manager.create_table() + + manager = Manager({Manager.name: "Guido"}) + await manager.save() diff --git a/tests/type_checking.py b/tests/type_checking.py new file mode 100644 index 000000000..47cf6944f --- /dev/null +++ b/tests/type_checking.py @@ -0,0 +1,131 @@ +""" +Making sure the types are inferred correctly by MyPy. + +Note: We need type annotations on the function, otherwise MyPy treats every +type inside the function as Any. 
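+
+``assert_type`` is a no-op at runtime; the type checker verifies that the
+inferred type of the first argument matches the second. A minimal
+illustration (not part of the checks below)::
+
+    assert_type(1 + 1, int)  # passes type checking
+    assert_type(1 + 1, str)  # flagged by MyPy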
+""" + +from typing import TYPE_CHECKING, Any, Optional + +from typing_extensions import assert_type + +from piccolo.columns import ForeignKey, Varchar +from piccolo.testing.model_builder import ModelBuilder +from piccolo.utils.sync import run_sync + +from .example_apps.music.tables import Band, Concert, Manager + +if TYPE_CHECKING: + + async def objects() -> None: + query = Band.objects() + assert_type(await query, list[Band]) + assert_type(await query.run(), list[Band]) + assert_type(query.run_sync(), list[Band]) + + async def objects_first() -> None: + query = Band.objects().first() + assert_type(await query, Optional[Band]) + assert_type(await query.run(), Optional[Band]) + assert_type(query.run_sync(), Optional[Band]) + + async def get() -> None: + query = Band.objects().get(Band.name == "Pythonistas") + assert_type(await query, Optional[Band]) + assert_type(await query.run(), Optional[Band]) + assert_type(query.run_sync(), Optional[Band]) + + async def foreign_key_reference() -> None: + assert_type(Band.manager, ForeignKey[Manager]) + + async def foreign_key_traversal() -> None: + # Single level + assert_type(Band.manager._.name, Varchar) + # Multi level + assert_type(Concert.band_1._.manager._.name, Varchar) + + async def get_related() -> None: + band = await Band.objects().get(Band.name == "Pythonistas") + assert band is not None + manager = await band.get_related(Band.manager) + assert_type(manager, Optional[Manager]) + + async def get_related_multiple_levels() -> None: + concert = await Concert.objects().first() + assert concert is not None + manager = await concert.get_related(Concert.band_1._.manager) + assert_type(manager, Optional[Manager]) + + async def get_or_create() -> None: + query = Band.objects().get_or_create(Band.name == "Pythonistas") + assert_type(await query, Band) + assert_type(await query.run(), Band) + assert_type(query.run_sync(), Band) + + async def select() -> None: + query = Band.select() + assert_type(await query, list[dict[str, Any]]) + assert_type(await query.run(), list[dict[str, Any]]) + assert_type(query.run_sync(), list[dict[str, Any]]) + + async def select_first() -> None: + query = Band.select().first() + assert_type(await query, Optional[dict[str, Any]]) + assert_type(await query.run(), Optional[dict[str, Any]]) + assert_type(query.run_sync(), Optional[dict[str, Any]]) + + async def select_list() -> None: + query = Band.select(Band.name).output(as_list=True) + assert_type(await query, list) + assert_type(await query.run(), list) + assert_type(query.run_sync(), list) + # The next step would be to detect that it's list[str], but might not + # be possible. 
+ + async def select_as_json() -> None: + query = Band.select(Band.name).output(as_json=True) + assert_type(await query, str) + assert_type(await query.run(), str) + assert_type(query.run_sync(), str) + + async def exists() -> None: + query = Band.exists() + assert_type(await query, bool) + assert_type(await query.run(), bool) + assert_type(query.run_sync(), bool) + + async def table_exists() -> None: + query = Band.table_exists() + assert_type(await query, bool) + assert_type(await query.run(), bool) + assert_type(query.run_sync(), bool) + + async def from_dict() -> None: + assert_type(Band.from_dict(data={}), Band) + + async def update() -> None: + query = Band.update() + assert_type(await query, list[Any]) + assert_type(await query.run(), list[Any]) + assert_type(query.run_sync(), list[Any]) + + async def insert() -> None: + # This is correct: + Band.insert(Band()) + # This is an error: + Band.insert(Manager()) # type: ignore + + async def model_builder() -> None: + assert_type(await ModelBuilder.build(Band), Band) + assert_type(ModelBuilder.build_sync(Band), Band) + + def run_sync_return_type() -> None: + """ + Make sure `run_sync` returns the same type as the coroutine which is + passed in. + """ + + async def my_func() -> str: + return "hello" + + assert_type(run_sync(my_func()), str) diff --git a/tests/utils/test_dictionary.py b/tests/utils/test_dictionary.py new file mode 100644 index 000000000..5817d007b --- /dev/null +++ b/tests/utils/test_dictionary.py @@ -0,0 +1,56 @@ +from unittest import TestCase + +from piccolo.utils.dictionary import make_nested + + +class TestMakeNested(TestCase): + def test_nesting(self): + response = make_nested( + { + "id": 1, + "name": "Pythonistas", + "manager.id": 1, + "manager.name": "Guido", + "manager.car.colour": "green", + } + ) + self.assertEqual( + response, + { + "id": 1, + "name": "Pythonistas", + "manager": { + "id": 1, + "name": "Guido", + "car": {"colour": "green"}, + }, + }, + ) + + def test_name_clash(self): + """ + In this example, `manager` and `manager.*` could potentially clash. + Nesting should take precedence. 
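+
+        A rough sketch of the expected behaviour::
+
+            >>> make_nested({"manager": 1, "manager.id": 1})
+            {'manager': {'id': 1}}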
+ """ + response = make_nested( + { + "id": 1, + "name": "Pythonistas", + "manager": 1, + "manager.id": 1, + "manager.name": "Guido", + "manager.car.colour": "green", + } + ) + self.assertEqual( + response, + { + "id": 1, + "name": "Pythonistas", + "manager": { + "id": 1, + "name": "Guido", + "car": {"colour": "green"}, + }, + }, + ) diff --git a/tests/utils/test_lazy_loader.py b/tests/utils/test_lazy_loader.py index 5be0d69ea..32a6be15b 100644 --- a/tests/utils/test_lazy_loader.py +++ b/tests/utils/test_lazy_loader.py @@ -1,14 +1,14 @@ from unittest import TestCase, mock from piccolo.utils.lazy_loader import LazyLoader -from tests.base import postgres_only, sqlite_only +from tests.base import engines_only, sqlite_only class TestLazyLoader(TestCase): def test_lazy_loading_database_driver(self): _ = LazyLoader("asyncpg", globals(), "asyncpg") - @postgres_only + @engines_only("postgres", "cockroach") def test_lazy_loader_asyncpg_exception(self): lazy_loader = LazyLoader("asyncpg", globals(), "asyncpg.connect") diff --git a/tests/utils/test_list.py b/tests/utils/test_list.py new file mode 100644 index 000000000..c21cec6b4 --- /dev/null +++ b/tests/utils/test_list.py @@ -0,0 +1,28 @@ +import string +from unittest import TestCase + +from piccolo.utils.list import batch, flatten + + +class TestFlatten(TestCase): + def test_flatten(self): + self.assertListEqual(flatten(["a", ["b", "c"]]), ["a", "b", "c"]) + + +class TestBatch(TestCase): + def test_batch(self): + self.assertListEqual( + batch([i for i in string.ascii_lowercase], chunk_size=5), + [ + ["a", "b", "c", "d", "e"], + ["f", "g", "h", "i", "j"], + ["k", "l", "m", "n", "o"], + ["p", "q", "r", "s", "t"], + ["u", "v", "w", "x", "y"], + ["z"], + ], + ) + + def test_zero(self): + with self.assertRaises(ValueError): + batch([1, 2, 3], chunk_size=0) diff --git a/tests/utils/test_pydantic.py b/tests/utils/test_pydantic.py new file mode 100644 index 000000000..ebfd78843 --- /dev/null +++ b/tests/utils/test_pydantic.py @@ -0,0 +1,963 @@ +import decimal +from typing import Optional, cast +from unittest import TestCase + +import pydantic +import pydantic_core +import pytest +from pydantic import ValidationError + +from piccolo.columns import ( + JSON, + JSONB, + UUID, + Array, + Email, + Integer, + Numeric, + Secret, + Text, + Time, + Timestamp, + Timestamptz, + Varchar, +) +from piccolo.columns.column_types import ForeignKey +from piccolo.table import Table +from piccolo.utils.pydantic import create_pydantic_model + + +class TestVarcharColumn(TestCase): + def test_varchar_length(self): + class Manager(Table): + name = Varchar(length=10) + + pydantic_model = create_pydantic_model(table=Manager) + + with self.assertRaises(ValidationError): + pydantic_model(name="This is a really long name") + + pydantic_model(name="short name") + + +class TestEmailColumn(TestCase): + def test_email(self): + class Manager(Table): + email = Email() + + pydantic_model = create_pydantic_model(table=Manager) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["email"]["anyOf"][ + 0 + ]["format"], + "email", + ) + + with self.assertRaises(ValidationError): + pydantic_model(email="not a valid email") + + # Shouldn't raise an exception: + pydantic_model(email="test@gmail.com") + + +class TestNumericColumn(TestCase): + """ + Numeric and Decimal are the same - so we'll just test Numeric. 
+ """ + + def test_numeric_digits(self): + class Band(Table): + royalties = Numeric(digits=(5, 1)) + + pydantic_model = create_pydantic_model(table=Band) + + with self.assertRaises(ValidationError): + # This should fail as there are too much numbers after the decimal + # point + pydantic_model(royalties=decimal.Decimal("1.11")) + + with self.assertRaises(ValidationError): + # This should fail as there are too much numbers in total + pydantic_model(royalties=decimal.Decimal("11111.1")) + + pydantic_model(royalties=decimal.Decimal("1.0")) + + def test_numeric_without_digits(self): + class Band(Table): + royalties = Numeric() + + try: + create_pydantic_model(table=Band) + except TypeError: + self.fail( + "Creating numeric field without" + " digits failed in pydantic model." + ) + else: + self.assertTrue(True) + + +class TestSecretColumn(TestCase): + def test_secret_param(self): + class TopSecret(Table): + confidential = Secret() + + pydantic_model = create_pydantic_model(table=TopSecret) + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["confidential"][ + "extra" + ]["secret"], + True, + ) + + +class TestArrayColumn(TestCase): + def test_array_param(self): + class Band(Table): + members = Array(base_column=Varchar(length=16)) + + pydantic_model = create_pydantic_model(table=Band) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["members"][ + "anyOf" + ][0]["items"]["type"], + "string", + ) + + def test_multidimensional_array(self): + """ + Make sure that multidimensional arrays have the correct type. + """ + + class Band(Table): + members = Array(Array(Varchar(length=255)), required=True) + + pydantic_model = create_pydantic_model(table=Band) + + self.assertEqual( + pydantic_model.model_fields["members"].annotation, + list[list[pydantic.constr(max_length=255)]], + ) + + # Should not raise a validation error: + pydantic_model( + members=[ + ["Alice", "Bob", "Francis"], + ["Alan", "Georgia", "Sue"], + ] + ) + + with self.assertRaises(ValueError): + pydantic_model(members=["Bob"]) + + +class TestForeignKeyColumn(TestCase): + def test_target_column(self): + """ + Make sure the `target_column` is correctly set in the Pydantic schema + for `ForeignKey` columns. + """ + + class Manager(Table): + name = Varchar(unique=True) + + class BandA(Table): + manager = ForeignKey(Manager, target_column=Manager.name) + + class BandB(Table): + manager = ForeignKey(Manager, target_column="name") + + class BandC(Table): + manager = ForeignKey(Manager) + + self.assertEqual( + create_pydantic_model(table=BandA).model_json_schema()[ + "properties" + ]["manager"]["extra"]["foreign_key"]["target_column"], + "name", + ) + + self.assertEqual( + create_pydantic_model(table=BandB).model_json_schema()[ + "properties" + ]["manager"]["extra"]["foreign_key"]["target_column"], + "name", + ) + + self.assertEqual( + create_pydantic_model(table=BandC).model_json_schema()[ + "properties" + ]["manager"]["extra"]["foreign_key"]["target_column"], + "id", + ) + + +class TestTextColumn(TestCase): + def test_text_widget(self): + """ + Make sure that we indicate that `Text` columns require a special widget + in Piccolo Admin. 
+ """ + + class Band(Table): + bio = Text() + + pydantic_model = create_pydantic_model(table=Band) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["bio"]["extra"][ + "widget" + ], + "text-area", + ) + + +class TestTimeColumn(TestCase): + def test_time_format(self): + class Concert(Table): + start_time = Time() + + pydantic_model = create_pydantic_model(table=Concert) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["start_time"][ + "anyOf" + ][0]["format"], + "time", + ) + + +class TestTimestamptzColumn(TestCase): + def test_timestamptz_widget(self): + """ + Make sure that we indicate that `Timestamptz` columns require a special + widget in Piccolo Admin. + """ + + class Concert(Table): + starts_on_1 = Timestamptz() + starts_on_2 = Timestamp() + + pydantic_model = create_pydantic_model(table=Concert) + + properties = pydantic_model.model_json_schema()["properties"] + + self.assertEqual( + properties["starts_on_1"]["extra"]["widget"], + "timestamptz", + ) + + self.assertIsNone(properties["starts_on_2"]["extra"].get("widget")) + + +class TestUUIDColumn(TestCase): + class Ticket(Table): + code = UUID() + + def setUp(self): + self.Ticket.create_table().run_sync() + + def tearDown(self): + self.Ticket.alter().drop_table().run_sync() + + def test_uuid_format(self): + class Ticket(Table): + code = UUID() + + pydantic_model = create_pydantic_model(table=Ticket) + + ticket = Ticket() + ticket.save().run_sync() + + # We'll also fetch it from the DB in case the database adapter's UUID + # is used. + ticket_from_db = Ticket.objects().first().run_sync() + assert ticket_from_db is not None + + for ticket_ in (ticket, ticket_from_db): + json = pydantic_model(**ticket_.to_dict()).model_dump_json() + self.assertEqual(json, '{"code":"' + str(ticket_.code) + '"}') + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["code"]["anyOf"][ + 0 + ]["format"], + "uuid", + ) + + +class TestColumnHelpText(TestCase): + """ + Make sure that columns with `help_text` attribute defined have the + relevant text appear in the schema. + """ + + def test_column_help_text_present(self): + help_text = "In millions of US dollars." + + class Band(Table): + royalties = Numeric(digits=(5, 1), help_text=help_text) + + pydantic_model = create_pydantic_model(table=Band) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["royalties"][ + "extra" + ]["help_text"], + help_text, + ) + + +class TestTableHelpText(TestCase): + """ + Make sure that tables with `help_text` attribute defined have the + relevant text appear in the schema. + """ + + def test_table_help_text_present(self): + help_text = "Bands playing concerts." 
+ + class Band(Table, help_text=help_text): + name = Varchar() + + pydantic_model = create_pydantic_model(table=Band) + + self.assertEqual( + pydantic_model.model_json_schema()["extra"]["help_text"], + help_text, + ) + + +class TestUniqueColumn(TestCase): + def test_unique_column_true(self): + class Manager(Table): + name = Varchar(unique=True) + + pydantic_model = create_pydantic_model(table=Manager) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["name"]["extra"][ + "unique" + ], + True, + ) + + def test_unique_column_false(self): + class Manager(Table): + name = Varchar() + + pydantic_model = create_pydantic_model(table=Manager) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["name"]["extra"][ + "unique" + ], + False, + ) + + +class TestJSONColumn(TestCase): + def test_default(self): + class Studio(Table): + facilities = JSON() + facilities_b = JSONB() + + pydantic_model = create_pydantic_model(table=Studio) + + json_string = '{"guitar_amps": 6}' + + model_instance = pydantic_model( + facilities=json_string, facilities_b=json_string + ) + self.assertEqual( + model_instance.facilities, + json_string, + ) + self.assertEqual( + model_instance.facilities_b, + json_string, + ) + + def test_deserialize_json(self): + class Studio(Table): + facilities = JSON() + facilities_b = JSONB() + + pydantic_model = create_pydantic_model( + table=Studio, deserialize_json=True + ) + + json_string = '{"guitar_amps": 6}' + output = {"guitar_amps": 6} + + model_instance = pydantic_model( + facilities=json_string, facilities_b=json_string + ) + self.assertEqual( + model_instance.facilities, + output, + ) + self.assertEqual( + model_instance.facilities_b, + output, + ) + + def test_validation(self): + class Studio(Table): + facilities = JSON() + facilities_b = JSONB() + + for deserialize_json in (True, False): + pydantic_model = create_pydantic_model( + table=Studio, deserialize_json=deserialize_json + ) + + json_string = "error" + + with self.assertRaises(pydantic.ValidationError): + pydantic_model( + facilities=json_string, facilities_b=json_string + ) + + def test_json_widget(self): + """ + Make sure that we indicate that `JSON` / `JSONB` columns require a + special widget in Piccolo Admin. 
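+
+        As with ``Text`` columns above, the hint should appear in the
+        schema's ``extra`` data as ``{"widget": "json"}``.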
+ """ + + class Studio(Table): + facilities = JSON() + + pydantic_model = create_pydantic_model(table=Studio) + + self.assertEqual( + pydantic_model.model_json_schema()["properties"]["facilities"][ + "extra" + ]["widget"], + "json", + ) + + def test_null_value(self): + class Studio(Table): + facilities = JSON(null=True) + facilities_b = JSONB(null=True) + + pydantic_model = create_pydantic_model(table=Studio) + movie = pydantic_model(facilities=None, facilities_b=None) + + self.assertIsNone(movie.facilities) + self.assertIsNone(movie.facilities_b) + + +class TestExcludeColumns(TestCase): + def test_all(self): + class Band(Table): + name = Varchar() + bio = Text() + + pydantic_model = create_pydantic_model(Band, exclude_columns=()) + + properties = pydantic_model.model_json_schema()["properties"] + self.assertIsInstance(properties["name"], dict) + self.assertIsInstance(properties["bio"], dict) + + def test_exclude(self): + class Band(Table): + name = Varchar() + album = Varchar() + + pydantic_model = create_pydantic_model( + Band, + exclude_columns=(Band.name,), + ) + + properties = pydantic_model.model_json_schema()["properties"] + self.assertIsInstance(properties.get("album"), dict) + self.assertIsNone(properties.get("dict")) + + def test_exclude_all_manually(self): + class Band(Table): + name = Varchar() + album = Varchar() + + pydantic_model = create_pydantic_model( + Band, + exclude_columns=(Band.name, Band.album), + ) + + self.assertEqual(pydantic_model.model_json_schema()["properties"], {}) + + def test_exclude_all_meta(self): + class Band(Table): + name = Varchar() + album = Varchar() + + pydantic_model = create_pydantic_model( + Band, + exclude_columns=tuple(Band._meta.columns), + ) + + self.assertEqual(pydantic_model.model_json_schema()["properties"], {}) + + def test_invalid_column_str(self): + class Band(Table): + name = Varchar() + album = Varchar() + + with self.assertRaises(ValueError): + create_pydantic_model( + Band, + exclude_columns=("album",), + ) + + def test_invalid_column_different_table(self): + class Band(Table): + name = Varchar() + album = Varchar() + + class Band2(Table): + photo = Varchar() + + with self.assertRaises(ValueError): + create_pydantic_model(Band, exclude_columns=(Band2.photo,)) + + def test_invalid_column_different_table_same_type(self): + class Band(Table): + name = Varchar() + album = Varchar() + + class Band2(Table): + name = Varchar() + + with self.assertRaises(ValueError): + create_pydantic_model(Band, exclude_columns=(Band2.name,)) + + def test_exclude_nested(self): + class Manager(Table): + name = Varchar() + phone_number = Integer() + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + popularity = Integer() + + pydantic_model = create_pydantic_model( + table=Band, + exclude_columns=( + Band.popularity, + Band.manager.phone_number, + ), + nested=(Band.manager,), + ) + + model_instance = pydantic_model( + name="Pythonistas", manager={"name": "Guido"} + ) + self.assertEqual( + model_instance.model_dump(), + {"name": "Pythonistas", "manager": {"name": "Guido"}}, + ) + + +class TestIncludeColumns(TestCase): + def test_include(self): + class Band(Table): + name = Varchar() + popularity = Integer() + + pydantic_model = create_pydantic_model( + Band, + include_columns=(Band.name,), + ) + + properties = pydantic_model.model_json_schema()["properties"] + self.assertIsInstance(properties.get("name"), dict) + self.assertIsNone(properties.get("popularity")) + + def test_include_exclude_error(self): + """ + An exception should be 
raised if both `include_columns` and + `exclude_columns` are provided. + """ + + class Band(Table): + name = Varchar() + popularity = Integer() + + with self.assertRaises(ValueError): + create_pydantic_model( + Band, + exclude_columns=(Band.name,), + include_columns=(Band.name,), + ) + + def test_nested(self): + """ + Make sure that columns on related tables work. + """ + + class Manager(Table): + name = Varchar() + phone_number = Integer() + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + popularity = Integer() + + pydantic_model = create_pydantic_model( + table=Band, + include_columns=( + Band.name, + Band.manager.name, + ), + nested=(Band.manager,), + ) + + model_instance = pydantic_model( + name="Pythonistas", manager={"name": "Guido"} + ) + self.assertEqual( + model_instance.model_dump(), + {"name": "Pythonistas", "manager": {"name": "Guido"}}, + ) + + +class TestNestedModel(TestCase): + def test_true(self): + """ + Make sure all foreign key columns are converted to nested models, when + `nested=True`. + """ + + class Country(Table): + name = Varchar(length=10) + + class Manager(Table): + name = Varchar(length=10) + country = ForeignKey(Country) + + class Band(Table): + name = Varchar(length=10) + manager = ForeignKey(Manager) + + BandModel = create_pydantic_model(table=Band, nested=True) + + ####################################################################### + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in ManagerModel.model_fields.keys()], ["name", "country"] + ) + + ####################################################################### + + CountryModel = cast( + type[pydantic.BaseModel], + ManagerModel.model_fields["country"].annotation, + ) + self.assertTrue(issubclass(CountryModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in CountryModel.model_fields.keys()], ["name"] + ) + + def test_tuple(self): + """ + Make sure only the specified foreign key columns are converted to + nested models. 
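+
+        A sketch of the shape being tested::
+
+            BandModel = create_pydantic_model(Band, nested=(Band.manager,))
+            # BandModel.manager           -> nested Pydantic model
+            # BandModel.assistant_manager -> Optional[int] (plain FK value)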
+ """ + + class Country(Table): + name = Varchar() + + class Manager(Table): + name = Varchar() + country = ForeignKey(Country) + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + assistant_manager = ForeignKey(Manager) + + class Venue(Table): + name = Varchar() + + class Concert(Table): + band_1 = ForeignKey(Band) + band_2 = ForeignKey(Band) + venue = ForeignKey(Venue) + + ####################################################################### + # Test one level deep + + BandModel = create_pydantic_model(table=Band, nested=(Band.manager,)) + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in ManagerModel.model_fields.keys()], ["name", "country"] + ) + self.assertEqual(ManagerModel.__qualname__, "Band.manager") + + AssistantManagerType = BandModel.model_fields[ + "assistant_manager" + ].annotation + self.assertIs(AssistantManagerType, Optional[int]) + + ####################################################################### + # Test two levels deep + + BandModel = create_pydantic_model( + table=Band, nested=(Band.manager._.country,) + ) + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in ManagerModel.model_fields.keys()], ["name", "country"] + ) + self.assertEqual(ManagerModel.__qualname__, "Band.manager") + + AssistantManagerType = cast( + type[pydantic.BaseModel], + BandModel.model_fields["assistant_manager"].annotation, + ) + self.assertIs(AssistantManagerType, Optional[int]) + + CountryModel = cast( + type[pydantic.BaseModel], + ManagerModel.model_fields["country"].annotation, + ) + self.assertTrue(issubclass(CountryModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in CountryModel.model_fields.keys()], ["name"] + ) + self.assertEqual(CountryModel.__qualname__, "Band.manager.country") + + ####################################################################### + # Test three levels deep + + ConcertModel = create_pydantic_model( + Concert, nested=(Concert.band_1._.manager,) + ) + + VenueModel = ConcertModel.model_fields["venue"].annotation + self.assertIs(VenueModel, Optional[int]) + + BandModel = cast( + type[pydantic.BaseModel], + ConcertModel.model_fields["band_1"].annotation, + ) + self.assertTrue(issubclass(BandModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in BandModel.model_fields.keys()], + ["name", "manager", "assistant_manager"], + ) + self.assertEqual(BandModel.__qualname__, "Concert.band_1") + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in ManagerModel.model_fields.keys()], + ["name", "country"], + ) + self.assertEqual(ManagerModel.__qualname__, "Concert.band_1.manager") + + AssistantManagerType = BandModel.model_fields[ + "assistant_manager" + ].annotation + self.assertIs(AssistantManagerType, Optional[int]) + + CountryModel = ManagerModel.model_fields["country"].annotation + self.assertIs(CountryModel, Optional[int]) + + ####################################################################### + # Test with `model_name` arg + + MyConcertModel = create_pydantic_model( + Concert, + nested=(Concert.band_1._.manager,), + model_name="MyConcertModel", + ) + + BandModel = cast( + 
type[pydantic.BaseModel], + MyConcertModel.model_fields["band_1"].annotation, + ) + self.assertEqual(BandModel.__qualname__, "MyConcertModel.band_1") + + ManagerModel = BandModel.model_fields["manager"].annotation + self.assertEqual( + ManagerModel.__qualname__, "MyConcertModel.band_1.manager" + ) + + def test_cascaded_args(self) -> None: + """ + Make sure that arguments passed to ``create_pydantic_model`` are + cascaded to nested models. + """ + + class Country(Table): + name = Varchar(length=10) + + class Manager(Table): + name = Varchar(length=10) + country = ForeignKey(Country) + + class Band(Table): + name = Varchar(length=10) + manager = ForeignKey(Manager) + + BandModel = create_pydantic_model( + table=Band, nested=True, include_default_columns=True + ) + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in ManagerModel.model_fields.keys()], + ["id", "name", "country"], + ) + + CountryModel = cast( + type[pydantic.BaseModel], + ManagerModel.model_fields["country"].annotation, + ) + self.assertTrue(issubclass(CountryModel, pydantic.BaseModel)) + self.assertEqual( + [i for i in CountryModel.model_fields.keys()], ["id", "name"] + ) + + +class TestRecursionDepth(TestCase): + def test_max(self): + class Country(Table): + name = Varchar() + + class Manager(Table): + name = Varchar() + country = ForeignKey(Country) + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + assistant_manager = ForeignKey(Manager) + + class Venue(Table): + name = Varchar() + + class Concert(Table): + band = ForeignKey(Band) + venue = ForeignKey(Venue) + + ConcertModel = create_pydantic_model( + table=Concert, nested=True, max_recursion_depth=2 + ) + + VenueModel = cast( + type[pydantic.BaseModel], + ConcertModel.model_fields["venue"].annotation, + ) + self.assertTrue(issubclass(VenueModel, pydantic.BaseModel)) + + BandModel = cast( + type[pydantic.BaseModel], + ConcertModel.model_fields["band"].annotation, + ) + self.assertTrue(issubclass(BandModel, pydantic.BaseModel)) + + ManagerModel = cast( + type[pydantic.BaseModel], + BandModel.model_fields["manager"].annotation, + ) + self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel)) + + # We should have hit the recursion depth: + CountryModel = ManagerModel.model_fields["country"].annotation + self.assertIs(CountryModel, Optional[int]) + + +class TestDBColumnName(TestCase): + def test_db_column_name(self): + """ + Make sure that the Pydantic model has an alias if ``db_column_name`` + is specified for a column. + """ + + class Band(Table): + name = Varchar(db_column_name="regrettable_column_name") + + BandModel = create_pydantic_model(table=Band) + + model = BandModel(regrettable_column_name="test") + + self.assertEqual(model.name, "test") + + +class TestJSONSchemaExtra(TestCase): + def test_json_schema_extra(self): + """ + Make sure that the ``json_schema_extra`` arguments are reflected in + Pydantic model's schema. + """ + + class Band(Table): + name = Varchar() + + model = create_pydantic_model( + Band, json_schema_extra={"extra": {"visible_columns": ("name",)}} + ) + self.assertEqual( + model.model_json_schema()["extra"]["visible_columns"], ("name",) + ) + + +class TestPydanticExtraFields(TestCase): + def test_pydantic_extra_fields(self) -> None: + """ + Make sure that the value of ``extra`` in the config class + is correctly propagated to the generated model. 
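+
+        E.g. with ``pydantic_config={"extra": "forbid"}``, passing an
+        unknown field (say, a hypothetical ``genre``) to the generated
+        model raises a ``ValidationError``.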
+ """ + + class Band(Table): + name = Varchar() + + config: pydantic.config.ConfigDict = {"extra": "forbid"} + model = create_pydantic_model(Band, pydantic_config=config) + + self.assertEqual(model.model_config.get("extra"), "forbid") + + def test_pydantic_invalid_extra_fields(self) -> None: + """ + Make sure that invalid values for ``extra`` in the config class + are rejected. + """ + + class Band(Table): + name = Varchar() + + config: pydantic.config.ConfigDict = { + "extra": "foobar" # type: ignore + } + + with pytest.raises(pydantic_core._pydantic_core.SchemaError): + create_pydantic_model(Band, pydantic_config=config) diff --git a/tests/utils/test_sql_values.py b/tests/utils/test_sql_values.py index 299cf96a0..a5bc0416d 100644 --- a/tests/utils/test_sql_values.py +++ b/tests/utils/test_sql_values.py @@ -1,7 +1,10 @@ +import time from enum import Enum from unittest import TestCase -from piccolo.columns.column_types import JSON, JSONB, Integer, Varchar +import pytest + +from piccolo.columns.column_types import JSON, JSONB, Array, Integer, Varchar from piccolo.table import Table from piccolo.utils.sql_values import convert_to_sql_value @@ -48,10 +51,10 @@ def test_convert_enum(self): """ class Colour(Enum): - red = "red" + red = "r" self.assertEqual( - convert_to_sql_value(value=Colour.red, column=Varchar()), "red" + convert_to_sql_value(value=Colour.red, column=Varchar()), "r" ) def test_other(self): @@ -62,3 +65,37 @@ def test_other(self): convert_to_sql_value(value=1, column=Integer()), 1, ) + + def test_convert_enum_list(self): + """ + It's possible to have a list of enums when using ``Array`` columns. + """ + + class Colour(Enum): + red = "r" + green = "g" + blue = "b" + + self.assertEqual( + convert_to_sql_value( + value=[Colour.red, Colour.green, Colour.blue], + column=Array(Varchar()), + ), + ["r", "g", "b"], + ) + + @pytest.mark.speed + def test_convert_large_list(self): + """ + Large lists are problematic. We need to check each value in the list, + but as efficiently as possible. + """ + start = time.time() + + convert_to_sql_value( + value=[i for i in range(1000)], + column=Array(Varchar()), + ) + + duration = time.time() - start + print(duration) diff --git a/tests/utils/test_table_reflection.py b/tests/utils/test_table_reflection.py new file mode 100644 index 000000000..88cba6910 --- /dev/null +++ b/tests/utils/test_table_reflection.py @@ -0,0 +1,92 @@ +from unittest import TestCase + +from piccolo.columns import Varchar +from piccolo.table import Table +from piccolo.table_reflection import TableStorage +from piccolo.utils.sync import run_sync +from tests.base import engines_only +from tests.example_apps.music.tables import Band, Manager + + +@engines_only("postgres", "cockroach") +class TestTableStorage(TestCase): + def setUp(self) -> None: + self.table_storage = TableStorage() + for table_class in (Manager, Band): + table_class.create_table().run_sync() + + def tearDown(self): + self.table_storage.clear() + for table_class in (Band, Manager): + table_class.alter().drop_table(if_exists=True).run_sync() + + def _compare_table_columns( + self, table_1: type[Table], table_2: type[Table] + ): + """ + Make sure that for each column in table_1, there is a corresponding + column in table_2 of the same type. 
+ """ + column_names = [ + column._meta.name for column in table_1._meta.non_default_columns + ] + for column_name in column_names: + col_1 = table_1._meta.get_column_by_name(column_name) + col_2 = table_2._meta.get_column_by_name(column_name) + + # Make sure they're the same type + self.assertEqual(type(col_1), type(col_2)) + + # Make sure they're both nullable or not + self.assertEqual(col_1._meta.null, col_2._meta.null) + + # Make sure the max length is the same + if isinstance(col_1, Varchar) and isinstance(col_2, Varchar): + self.assertEqual(col_1.length, col_2.length) + + # Make sure the unique constraint is the same + self.assertEqual(col_1._meta.unique, col_2._meta.unique) + + def test_reflect_all_tables(self): + run_sync(self.table_storage.reflect()) + reflected_tables = self.table_storage.tables + self.assertEqual(len(reflected_tables), 2) + for table_class in (Manager, Band): + self._compare_table_columns( + reflected_tables[table_class._meta.tablename], table_class + ) + + def test_reflect_with_include(self): + run_sync(self.table_storage.reflect(include=["manager"])) + reflected_tables = self.table_storage.tables + self.assertEqual(len(reflected_tables), 1) + self._compare_table_columns(reflected_tables["manager"], Manager) + + def test_reflect_with_exclude(self): + run_sync(self.table_storage.reflect(exclude=["band"])) + reflected_tables = self.table_storage.tables + self.assertEqual(len(reflected_tables), 1) + self._compare_table_columns(reflected_tables["manager"], Manager) + + def test_get_present_table(self): + run_sync(self.table_storage.reflect()) + table = run_sync(self.table_storage.get_table(tablename="manager")) + self._compare_table_columns(table, Manager) + + def test_get_unavailable_table(self): + run_sync(self.table_storage.reflect(exclude=["band"])) + # make sure only one table is present + self.assertEqual(len(self.table_storage.tables), 1) + table = run_sync(self.table_storage.get_table(tablename="band")) + # make sure the returned table is correct + self._compare_table_columns(table, Band) + # make sure the requested table has been added to the TableStorage + self.assertEqual(len(self.table_storage.tables), 2) + self.assertIsNotNone(self.table_storage.tables.get("band")) + + def test_get_schema_and_table_name(self): + tableNameDetail = self.table_storage._get_schema_and_table_name( + "music.manager" + ) + self.assertEqual(tableNameDetail.name, "manager") + self.assertEqual(tableNameDetail.schema, "music")