Skip to content

Commit 3dfe15f

Browse files
committed
Make mock_client_session a proper Pytest fixture
1 parent afca2a9 commit 3dfe15f

File tree

4 files changed

+73
-65
lines changed

4 files changed

+73
-65
lines changed

tests/conftest.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,29 @@ def read_file(self, state):
7373
return file.read()
7474

7575

76-
@asynccontextmanager
77-
async def mock_client_session():
78-
"""Context manager that replaces the global client_session with an AsyncMock instance.
76+
@pytest.fixture(scope="class")
77+
def mock_client_session_class(request):
78+
"""Class fixture to expose an AsyncMock to unittest.TestCase subclasses.
7979
80-
:Example:
80+
See: https://docs.pytest.org/en/5.4.1/unittest.html#mixing-pytest-fixtures-into-unittest-testcase-subclasses-using-marks
81+
"""
82+
83+
httputils.client_session = request.cls.mock_client_session = mock.AsyncMock()
84+
try:
85+
yield
86+
finally:
87+
del httputils.client_session
8188

82-
>>> async with mock_client_session() as mocked_client_session:
83-
>>> mocked_client_session.get = mocked_session_get
84-
>>> # test code...
8589

90+
@pytest.fixture
91+
async def mock_client_session():
92+
"""Context manager fixture that replaces the global client_session with an AsyncMock
93+
instance.
8694
"""
8795

88-
httputils.client_session = mocked_client_session = mock.AsyncMock()
96+
httputils.client_session = mock.AsyncMock()
8997
try:
90-
yield mocked_client_session
98+
yield httputils.client_session
9199
finally:
92100
del httputils.client_session
93101

tests/test_csbs.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22

33
from app.services.location import csbs
4-
from tests.conftest import mock_client_session
54
from tests.conftest import mocked_session_get
65

76

@@ -28,10 +27,9 @@ def read_file(self):
2827

2928

3029
@pytest.mark.asyncio
31-
async def test_get_locations():
32-
async with mock_client_session() as mocked_client_session:
33-
mocked_client_session.get = mocked_session_get
34-
data = await csbs.get_locations()
30+
async def test_get_locations(mock_client_session):
31+
mock_client_session.get = mocked_session_get
32+
data = await csbs.get_locations()
3533

3634
assert isinstance(data, list)
3735

tests/test_jhu.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from app import location
66
from app.services.location import jhu
7-
from tests.conftest import mock_client_session
87
from tests.conftest import mocked_session_get
98
from tests.conftest import mocked_strptime_isoformat
109

@@ -13,17 +12,16 @@
1312

1413
@pytest.mark.asyncio
1514
@mock.patch("app.services.location.jhu.datetime")
16-
async def test_get_locations(mock_datetime):
15+
async def test_get_locations(mock_datetime, mock_client_session):
1716
mock_datetime.utcnow.return_value.isoformat.return_value = DATETIME_STRING
1817
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
1918

20-
async with mock_client_session() as mocked_client_session:
21-
mocked_client_session.get = mocked_session_get
22-
output = await jhu.get_locations()
19+
mock_client_session.get = mocked_session_get
20+
output = await jhu.get_locations()
2321

24-
assert isinstance(output, list)
25-
assert isinstance(output[0], location.Location)
22+
assert isinstance(output, list)
23+
assert isinstance(output[0], location.Location)
2624

27-
# `jhu.get_locations()` creates id based on confirmed list
28-
location_confirmed = await jhu.get_category("confirmed")
29-
assert len(output) == len(location_confirmed["locations"])
25+
# `jhu.get_locations()` creates id based on confirmed list
26+
location_confirmed = await jhu.get_category("confirmed")
27+
assert len(output) == len(location_confirmed["locations"])

tests/test_routes.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
import pytest
77
from async_asgi_testclient import TestClient
88

9-
from .conftest import mock_client_session
109
from .conftest import mocked_session_get
1110
from .conftest import mocked_strptime_isoformat
1211
from .test_jhu import DATETIME_STRING
1312
from app.main import APP
1413

1514

15+
@pytest.mark.usefixtures("mock_client_session_class")
1616
@pytest.mark.asyncio
17-
@mock.patch("app.services.location.jhu.datetime")
1817
class FlaskRoutesTest(unittest.TestCase):
1918
"""
2019
Need to mock some objects to control testing data locally
@@ -32,84 +31,91 @@ def read_file_v1(self, state):
3231
expected_json_output = file.read()
3332
return expected_json_output
3433

34+
@mock.patch("app.services.location.jhu.datetime")
3535
async def test_root_api(self, mock_datetime):
3636
"""Validate that / returns a 200 and is not a redirect."""
37-
async with mock_client_session() as mocked_client_session:
38-
mocked_client_session.get = mocked_session_get
39-
response = await self.asgi_client.get("/")
37+
self.mock_client_session.get = mocked_session_get
38+
39+
response = await self.asgi_client.get("/")
4040

4141
assert response.status_code == 200
4242
assert not response.is_redirect
4343

44+
@mock.patch("app.services.location.jhu.datetime")
4445
async def test_v1_confirmed(self, mock_datetime):
4546
mock_datetime.utcnow.return_value.isoformat.return_value = self.date
4647
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
48+
self.mock_client_session.get = mocked_session_get
49+
4750
state = "confirmed"
4851
expected_json_output = self.read_file_v1(state=state)
49-
async with mock_client_session() as mocked_client_session:
50-
mocked_client_session.get = mocked_session_get
51-
response = await self.asgi_client.get("/{}".format(state))
52-
return_data = response.json()
52+
response = await self.asgi_client.get("/{}".format(state))
53+
return_data = response.json()
5354

5455
assert return_data == json.loads(expected_json_output)
5556

57+
@mock.patch("app.services.location.jhu.datetime")
5658
async def test_v1_deaths(self, mock_datetime):
5759
mock_datetime.utcnow.return_value.isoformat.return_value = self.date
5860
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
61+
self.mock_client_session.get = mocked_session_get
62+
5963
state = "deaths"
6064
expected_json_output = self.read_file_v1(state=state)
61-
async with mock_client_session() as mocked_client_session:
62-
mocked_client_session.get = mocked_session_get
63-
response = await self.asgi_client.get("/{}".format(state))
64-
return_data = response.json()
65+
response = await self.asgi_client.get("/{}".format(state))
66+
return_data = response.json()
6567

6668
assert return_data == json.loads(expected_json_output)
6769

70+
@mock.patch("app.services.location.jhu.datetime")
6871
async def test_v1_recovered(self, mock_datetime):
6972
mock_datetime.utcnow.return_value.isoformat.return_value = self.date
7073
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
74+
self.mock_client_session.get = mocked_session_get
75+
7176
state = "recovered"
7277
expected_json_output = self.read_file_v1(state=state)
73-
async with mock_client_session() as mocked_client_session:
74-
mocked_client_session.get = mocked_session_get
75-
response = await self.asgi_client.get("/{}".format(state))
76-
return_data = response.json()
78+
response = await self.asgi_client.get("/{}".format(state))
79+
return_data = response.json()
7780

7881
assert return_data == json.loads(expected_json_output)
7982

83+
@mock.patch("app.services.location.jhu.datetime")
8084
async def test_v1_all(self, mock_datetime):
8185
mock_datetime.utcnow.return_value.isoformat.return_value = self.date
8286
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
87+
self.mock_client_session.get = mocked_session_get
88+
8389
state = "all"
8490
expected_json_output = self.read_file_v1(state=state)
85-
async with mock_client_session() as mocked_client_session:
86-
mocked_client_session.get = mocked_session_get
87-
response = await self.asgi_client.get("/{}".format(state))
88-
return_data = response.json()
91+
response = await self.asgi_client.get("/{}".format(state))
92+
return_data = response.json()
8993

9094
assert return_data == json.loads(expected_json_output)
9195

96+
@mock.patch("app.services.location.jhu.datetime")
9297
async def test_v2_latest(self, mock_datetime):
9398
mock_datetime.utcnow.return_value.isoformat.return_value = DATETIME_STRING
9499
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
100+
self.mock_client_session.get = mocked_session_get
101+
95102
state = "latest"
96-
async with mock_client_session() as mocked_client_session:
97-
mocked_client_session.get = mocked_session_get
98-
response = await self.asgi_client.get(f"/v2/{state}")
99-
return_data = response.json()
103+
response = await self.asgi_client.get(f"/v2/{state}")
104+
return_data = response.json()
100105

101106
check_dict = {"latest": {"confirmed": 1940, "deaths": 1940, "recovered": 0}}
102107

103108
assert return_data == check_dict
104109

110+
@mock.patch("app.services.location.jhu.datetime")
105111
async def test_v2_locations(self, mock_datetime):
106112
mock_datetime.utcnow.return_value.isoformat.return_value = DATETIME_STRING
107113
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
114+
self.mock_client_session.get = mocked_session_get
115+
108116
state = "locations"
109-
async with mock_client_session() as mocked_client_session:
110-
mocked_client_session.get = mocked_session_get
111-
response = await self.asgi_client.get("/v2/{}".format(state))
112-
return_data = response.json()
117+
response = await self.asgi_client.get("/v2/{}".format(state))
118+
return_data = response.json()
113119

114120
filepath = "tests/expected_output/v2_{state}.json".format(state=state)
115121
with open(filepath, "r") as file:
@@ -118,16 +124,16 @@ async def test_v2_locations(self, mock_datetime):
118124
# TODO: Why is this failing?
119125
# assert return_data == json.loads(expected_json_output)
120126

127+
@mock.patch("app.services.location.jhu.datetime")
121128
async def test_v2_locations_id(self, mock_datetime):
122129
mock_datetime.utcnow.return_value.isoformat.return_value = DATETIME_STRING
123130
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
131+
self.mock_client_session.get = mocked_session_get
124132

125133
state = "locations"
126134
test_id = 1
127-
async with mock_client_session() as mocked_client_session:
128-
mocked_client_session.get = mocked_session_get
129-
response = await self.asgi_client.get("/v2/{}/{}".format(state, test_id))
130-
return_data = response.json()
135+
response = await self.asgi_client.get("/v2/{}/{}".format(state, test_id))
136+
return_data = response.json()
131137

132138
filepath = "tests/expected_output/v2_{state}_id_{test_id}.json".format(state=state, test_id=test_id)
133139
with open(filepath, "r") as file:
@@ -151,10 +157,9 @@ async def test_v2_locations_id(self, mock_datetime):
151157
({"source": "jhu", "country_code": "US"}, 404),
152158
],
153159
)
154-
async def test_locations_status_code(async_api_client, query_params, expected_status):
155-
async with mock_client_session() as mocked_client_session:
156-
mocked_client_session.get = mocked_session_get
157-
response = await async_api_client.get("/v2/locations", query_string=query_params)
160+
async def test_locations_status_code(async_api_client, query_params, expected_status, mock_client_session):
161+
mock_client_session.get = mocked_session_get
162+
response = await async_api_client.get("/v2/locations", query_string=query_params)
158163

159164
print(f"GET {response.url}\n{response}")
160165
print(f"\tjson:\n{pf(response.json())[:1000]}\n\t...")
@@ -173,10 +178,9 @@ async def test_locations_status_code(async_api_client, query_params, expected_st
173178
{"source": "jhu", "timelines": True},
174179
],
175180
)
176-
async def test_latest(async_api_client, query_params):
177-
async with mock_client_session() as mocked_client_session:
178-
mocked_client_session.get = mocked_session_get
179-
response = await async_api_client.get("/v2/latest", query_string=query_params)
181+
async def test_latest(async_api_client, query_params, mock_client_session):
182+
mock_client_session.get = mocked_session_get
183+
response = await async_api_client.get("/v2/latest", query_string=query_params)
180184

181185
print(f"GET {response.url}\n{response}")
182186

0 commit comments

Comments
 (0)