Skip to content

Commit facac58

Browse files
committed
added a shim so coroutines can be mocked in python 3.7 tests
1 parent 7c51cba commit facac58

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

tests/apps/migrations/auto/test_migration_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tests.example_app.tables import Manager
1010
from tests.base import DBTestCase
11-
from tests.base import postgres_only
11+
from tests.base import postgres_only, set_mock_return_value
1212

1313

1414
class TestMigrationManager(DBTestCase):
@@ -202,7 +202,7 @@ def test_drop_column(self, get_migration_managers: MagicMock):
202202
self.assertEqual(response, [{"id": 1}])
203203

204204
# Reverse
205-
get_migration_managers.return_value = [manager_1]
205+
set_mock_return_value(get_migration_managers, [manager_1])
206206
asyncio.run(manager_2.run_backwards())
207207
response = self.run_sync("SELECT * FROM musician;")
208208
self.assertEqual(response, [{"id": 1, "name": ""}])
@@ -380,7 +380,7 @@ def test_drop_table(self, get_migration_managers: MagicMock):
380380
manager_2.drop_table(class_name="Musician", tablename="musician")
381381
asyncio.run(manager_2.run())
382382

383-
get_migration_managers.return_value = [manager_1]
383+
set_mock_return_value(get_migration_managers, [manager_1])
384384

385385
self.assertTrue(not self.table_exists("musician"))
386386

tests/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import typing as t
33
from unittest import TestCase
4+
from unittest.mock import MagicMock
45

56
import pytest
67

@@ -23,6 +24,22 @@
2324
)
2425

2526

27+
def set_mock_return_value(magic_mock: MagicMock, return_value: t.Any):
28+
"""
29+
Python 3.8 has good support for mocking coroutines. For older versions,
30+
we must set the return value to be an awaitable explicitly.
31+
"""
32+
if magic_mock.__class__.__name__ == "AsyncMock":
33+
# Python 3.8 and above
34+
magic_mock.return_value = return_value
35+
else:
36+
37+
async def coroutine(*args, **kwargs):
38+
return return_value
39+
40+
magic_mock.return_value = coroutine()
41+
42+
2643
class DBTestCase(TestCase):
2744
"""
2845
Using raw SQL where possible, otherwise the tests are too reliant on other

0 commit comments

Comments
 (0)